
feat(actor): expose per-step GSP prediction loss via last_gsp_loss #23

Merged
jdbloom merged 2 commits into main from feature/gsp-loss-onto-main
Apr 13, 2026

Conversation

@jdbloom (Contributor) commented Apr 13, 2026

Summary

Cherry-picks the work originally landed in PR #22 (`feat/rddpg-lstm-fix` branch) onto main, now that PR #21 (NaN detection) has merged.

  • `Actor.last_gsp_loss` attribute populated by `learn_gsp()` each tick, reset to None at the start of `learn()` so callers can distinguish "no GSP step ran" from "GSP step ran"
  • `learn_TD3` signature now accepts `recurrent=False`, matching DDPG/RDDPG (previously the `learn_gsp` dispatch made a latent 3-arg call into a 2-arg method)
  • TD3's non-actor-update tuple path skips recording instead of writing a spurious 0.0
  • Docstring on `last_gsp_loss` clarifies it's the GSP learner's training loss (actor loss for DDPG/RDDPG/TD3, MSE for attention) — for prediction-collapse detection consumers should rely on raw squared error captured in RL-CollectiveTransport
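The reset-then-populate contract above can be sketched as follows. This is a minimal illustration of the semantics described in the PR, not the actual GSP-RL implementation; the internal `_learn_policy` helper and the loss values are placeholders.

```python
class Actor:
    """Sketch of the last_gsp_loss contract: None means 'no GSP step ran this tick'."""

    def __init__(self, gsp_enabled=False):
        self.gsp_enabled = gsp_enabled
        self.last_gsp_loss = None  # None until a GSP learning step fires

    def learn(self):
        # Reset first, so a tick where no GSP step fires reads as None
        # rather than a stale value carried over from a previous tick.
        self.last_gsp_loss = None
        actor_loss = self._learn_policy()
        if self.gsp_enabled:
            self.last_gsp_loss = self.learn_gsp()
        return actor_loss

    def _learn_policy(self):
        return 0.0  # placeholder for the real actor/critic update

    def learn_gsp(self):
        return 0.42  # placeholder GSP training loss
```

A caller can then safely distinguish the two cases with `if actor.last_gsp_loss is not None:` after each `learn()` call.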

Why

The information-collapse diagnostic (see Stelaris `docs/specs/2026-04-12-dispatcher-diagnostic-batch.md`) needs to log the GSP prediction network's training loss per learning step. Until this PR, only the actor/critic loss was returned from `Actor.learn()`, so a fully collapsed GSP head would look completely normal in the logs.

Test plan

  • 4 new tests in `tests/test_actor/test_gsp_loss_exposure.py` (TDD — watched each fail first):
    • `last_gsp_loss` initialized to None
    • Populated after learn() with GSP enabled + filled buffers
    • Resets between ticks (the load-bearing invariant)
    • Stays None when GSP is disabled
  • Full actor + learning_aids suite: 70/70 pass on top of NaN detection (PR #21, "feat: add NaN/Inf detection in learning pipeline")

Provenance

This is a clean re-application of commits `e1b138d` and `e7f0a55` from PR #22, which landed on `feat/rddpg-lstm-fix` (a feature branch, not main). The PR #22 review approved that work with non-blocking concerns, all of which a follow-up commit addressed; a second-pass review of the fix commit also approved.

🤖 Generated with Claude Code

Joshua Bloom and others added 2 commits April 13, 2026 06:31
The GSP prediction network's training loss was never surfaced through
Actor.learn(). Only the actor/critic loss was returned, which stays normal
even when the GSP head collapses to a near-constant output. Add
last_gsp_loss attribute populated by learn_gsp() whenever a GSP learning
step fires, reset to None at the start of each learn() call so callers can
distinguish "no GSP step this tick" from "GSP step ran".

Needed for the information-collapse diagnostic (see Stelaris
docs/specs/2026-04-12-dispatcher-diagnostic-batch.md) — without it we
cannot tell whether non-recurrent GSP variants are learning or degenerate.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…reset test

- learn_TD3 now accepts recurrent=False to match the DDPG/RDDPG signatures;
  the learn_gsp dispatch was passing 3 positional args to a 2-arg method.
  Latent bug today (GSP networks are built as DDPG/attention, not TD3) but
  removes the footgun before the diagnostic batch exercises TD3 variants.
- TD3's non-actor-update step returns (0, 0); previously we unwrapped to
  0.0 and logged it. That produces false collapse signals every
  update_actor_iter-1 ticks. Now we skip the entry entirely — leave
  last_gsp_loss at None as if no GSP step ran.
- Doc the semantic: last_gsp_loss is the GSP learner's training loss,
  which is actor loss (policy-gradient signal) for DDPG/RDDPG/TD3 and
  genuine MSE only for attention. For prediction-collapse detection
  consumers should rely on gsp_squared_error and the HDF5Logger
  episode-level gsp_output_std / gsp_pred_target_corr attrs.
- Add reset-between-ticks test covering the load-bearing invariant that
  last_gsp_loss returns to None when a learn() call runs but no GSP
  learning step fires.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jdbloom jdbloom merged commit 2d67a56 into main Apr 13, 2026
3 of 4 checks passed
@jdbloom jdbloom deleted the feature/gsp-loss-onto-main branch April 13, 2026 10:32
jdbloom pushed a commit to NESTLab/RL-CollectiveTransport that referenced this pull request Apr 13, 2026
Adds the per-step and per-episode HDF5 fields needed to detect "GSP
information collapse" — a suspected failure mode where the GSP prediction
network collapses to a near-constant output that carries no information
about the collective state. See Stelaris
docs/specs/2026-04-12-dispatcher-diagnostic-batch.md for the hypothesis.

Changes (all gated on opt-in — backward compatible):

env.py:
- calculate_gsp_reward returns (reward, label, squared_errors). The raw
  per-robot (diff - prediction)^2 carries the magnitude that the clipped
  [-2, 0] reward hides.

rl_code/src/hdf5_logger.py:
- New optional kwargs gsp_target, gsp_squared_error on writerow → 2D
  (timesteps × robots) datasets.
- New record_gsp_loss(value) method → 1D dataset at GSP learning cadence.
- write_episode now computes two episode-level summary attrs when both
  prediction and target buffers are present:
  - gsp_output_std (collapse signature: → 0)
  - gsp_pred_target_corr (collapse signature: → NaN when std is below
    1e-12 tolerance, distinguishing "undefined" from "measured zero")
  Uses np.nanstd and pair-wise NaN masking so a single physics glitch
  doesn't poison the summary; raises ValueError if gsp_target/gsp_heading
  buffers desync within an episode.
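The episode-level summary computation could look roughly like the sketch below. The attribute names (`gsp_output_std`, `gsp_pred_target_corr`), the 1e-12 tolerance, `np.nanstd`, and the pair-wise NaN masking come from the commit message; the exact function shape and signature are assumptions.

```python
import numpy as np

def gsp_episode_summary(pred, target, tol=1e-12):
    """Sketch of the episode-level collapse diagnostics (not the real logger code)."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    if pred.shape != target.shape:
        raise ValueError("gsp prediction/target buffers desynced within episode")
    # Collapse signature: output std tends to 0 as the head goes constant.
    out_std = float(np.nanstd(pred))
    # Pair-wise NaN masking so a single physics glitch doesn't poison the summary.
    mask = ~(np.isnan(pred) | np.isnan(target))
    p, t = pred[mask], target[mask]
    if p.size < 2 or p.std() < tol or t.std() < tol:
        # Undefined correlation (std below tolerance) is reported as NaN,
        # distinguishing "undefined" from "measured zero".
        corr = float("nan")
    else:
        corr = float(np.corrcoef(p, t)[0, 1])
    return out_std, corr
```

A fully collapsed episode then reads as `(0.0, nan)`, while a healthy predictor shows nonzero std and a finite correlation.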

Main.py:
- 3-tuple unpack of calculate_gsp_reward; broadcast scalar label to
  per-robot list for the (timesteps × robots) HDF5 schema; pass new
  kwargs to hdf5_writer.writerow.
- After each model.learn() call, capture model.last_gsp_loss (from
  GSP-RL PR #23) and pass to hdf5_writer.record_gsp_loss. In
  --independent_learning mode, aggregate across per-robot models to a
  single scalar per learn tick (mean) so the gsp_loss axis length stays
  num_learn_steps regardless of mode.

Tests:
- 6 new TestGSPSquaredErrorReturn cases in test_env/test_gsp_reward.py;
  existing tests updated to 3-tuple unpack.
- tests/test_diagnostics/test_hdf5_logger_gsp_diagnostics.py — 9 new
  tests covering: per-step datasets, gsp_loss recording, episode attrs,
  collapse signature detection, degenerate task, NaN poisoning,
  desynced-buffer raise, backward compat, optional record_gsp_loss.

Companion: NESTLab/GSP-RL#23 (Actor.last_gsp_loss).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
