
feat(actor): expose per-step GSP prediction loss via last_gsp_loss#22

Merged
jdbloom merged 2 commits into feat/rddpg-lstm-fix from feature/hdf5-gsp-diagnostics
Apr 13, 2026

Conversation


jdbloom (Contributor) commented Apr 13, 2026

Summary

  • Adds Actor.last_gsp_loss attribute populated by learn_gsp() each time a GSP learning step fires
  • Reset to None at the start of each learn() call so callers can distinguish "no GSP step this tick" from "GSP step ran"
  • All four GSP-capable inner learn paths (DDPG, RDDPG, TD3, attention) surface their loss; TD3's (0, 0) edge case is normalized to a scalar
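The reset-then-populate contract described above can be sketched as follows. This is a minimal illustration, not the actual gsp_rl `Actor`; the class name, method bodies, and the stand-in loss value are hypothetical.

```python
class ActorSketch:
    """Minimal stand-in illustrating the last_gsp_loss contract."""

    def __init__(self):
        self.last_gsp_loss = None  # None until a GSP learning step fires

    def learn(self):
        # Reset first, so callers can distinguish "no GSP step this tick"
        # (stays None) from "GSP step ran" (becomes a float).
        self.last_gsp_loss = None
        loss = self.learn_gsp()
        if loss is not None:
            self.last_gsp_loss = float(loss)

    def learn_gsp(self):
        # Stand-in for the real dispatch into learn_DDPG / learn_RDDPG /
        # learn_TD3 / learn_attention; returns a scalar loss, or None
        # when no GSP learning step fires this tick.
        return 0.25
```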

Why

The primary loss returned from Actor.learn() is the actor/critic loss, which stays completely normal even when the GSP prediction head has collapsed to a near-constant output. Without surfacing the GSP network's own training loss, we cannot diagnose the "information collapse" hypothesis called out in the Memory-Enhanced GSP paper outline (`Revamped Reward structure for GSP to prevent information collapse`). See companion PR in Stelaris / RL-CollectiveTransport (feature/hdf5-gsp-diagnostics).

Test plan

  • New test file `tests/test_actor/test_gsp_loss_exposure.py` (3 tests, all TDD — watched them fail first)
    • `last_gsp_loss` starts as None
    • Populated with a float after learn() with GSP enabled + filled buffers
    • Stays None when GSP is disabled
  • Full actor + learning_aids test suite: 69/69 pass
  • No behavior change for non-GSP actors

🤖 Generated with Claude Code

The GSP prediction network's training loss was never surfaced through
Actor.learn(). Only the actor/critic loss was returned, which stays normal
even when the GSP head collapses to a near-constant output. Add
last_gsp_loss attribute populated by learn_gsp() whenever a GSP learning
step fires, reset to None at the start of each learn() call so callers can
distinguish "no GSP step this tick" from "GSP step ran".

Needed for the information-collapse diagnostic (see Stelaris
docs/specs/2026-04-12-dispatcher-diagnostic-batch.md) — without it we
cannot tell whether non-recurrent GSP variants are learning or degenerate.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

jdbloom commented Apr 13, 2026

Verdict: Approve with minor concerns (non-blocking)

The loss-capture path is correct, the reset timing is right, the TD3 tuple handling is sound, and there are no attribute collisions or Python-version issues (^3.10 in pyproject.toml supports float | None). Safe to merge for the diagnostic batch. A few things worth noting before you interpret the numbers downstream.

Concerns

1. Semantic mismatch between scheme branches (gsp_rl/src/actors/learning_aids.py:306,386,438,450). learn_DDPG/learn_RDDPG/learn_TD3 all return actor_loss.item() — the GSP actor loss, i.e. -critic(s, actor(s)).mean(). Only learn_attention returns a genuine prediction MSE against the label. A collapsed GSP prediction head in DDPG/RDDPG/TD3 variants will not necessarily show up as an anomalous last_gsp_loss, because the actor loss is a critic-derived policy-gradient signal, not a prediction error. If the collapse diagnostic depends on seeing prediction MSE specifically, this attribute is not sufficient for non-attention schemes — consider also exposing value_loss from the GSP critic, or better, a dedicated prediction-error probe computed outside learn_*. Flagging because the PR description and attribute name imply "GSP prediction network training loss," which is only literally true for attention.
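A scheme-independent probe along the lines suggested above might look like the following. `gsp_prediction_probe` and its arguments are hypothetical, not part of the gsp_rl API; the point is that prediction error and output spread are computed outside `learn_*`, so a collapsed head is visible even when the scheme's returned loss is a policy-gradient signal.

```python
import numpy as np

def gsp_prediction_probe(predict_fn, states, labels):
    """Hypothetical probe: prediction MSE plus output spread, computed
    independently of the learning scheme's own loss."""
    preds = predict_fn(states)
    mse = float(np.mean((preds - labels) ** 2))
    output_std = float(np.std(preds))  # near-zero std suggests collapse
    return mse, output_std
```

A head collapsed to a constant output yields a large MSE against varying labels and an `output_std` of exactly zero, regardless of which `learn_*` branch trained it.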

2. TD3 (0, 0) path produces legitimate 0.0 readings (learning_aids.py:427-428). On non-actor-update TD3 steps, last_gsp_loss becomes 0.0 (critic ran, actor did not). A downstream collapse detector that flags small values will get false positives on update_actor_iter - 1 out of every update_actor_iter ticks. The PR normalization is fine; the consumer needs to know.

3. Pre-existing TD3 GSP signature bug (actor.py:451 vs learning_aids.py:388). Not introduced here, but surfaced by this PR. learn_gsp calls self.learn_TD3(self.gsp_networks, self.gsp, self.recurrent_gsp) (3 args) while learn_TD3(self, networks, gsp=False) takes 2. The TD3 GSP branch would TypeError if ever exercised. Same latent bug existed before; worth filing as a separate issue since the new TD3 test coverage suggested below would trip on it.
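The mismatch can be reproduced in miniature. The class and bodies below are stand-ins mirroring only the signatures quoted above, not the real actor.py/learning_aids.py code.

```python
class ActorStub:
    def learn_TD3(self, networks, gsp=False):
        # Mirrors the 2-argument learn_TD3 signature from learning_aids.py
        return 0.0

    def learn_gsp(self):
        # Mirrors the 3-positional-arg call site in actor.py; raises
        # TypeError the moment the TD3 GSP branch is exercised.
        return self.learn_TD3({}, True, False)
```

Calling `ActorStub().learn_gsp()` raises `TypeError` immediately, which is why the branch would fail on first use rather than silently misbehave.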

Test gaps (suggestions, not blockers)

tests/test_actor/test_gsp_loss_exposure.py covers DDPG-gsp and the gsp-disabled control. Missing:

  • Reset between ticks: call learn() twice, second call with an empty gsp buffer, assert last_gsp_loss returns to None. This is the load-bearing invariant of the reset.
  • Attention variant: the one branch where the returned loss really is a prediction MSE, and the one most relevant to the collapse diagnostic.
  • TD3 tuple path: construct an actor with UPDATE_ACTOR_ITER > 1 and step twice to hit both the (0, 0) branch and the scalar branch. (Blocked by concern 3 above for the GSP case, but works for primary-network TD3 as an isolated test of the normalization logic.)
  • RDDPG variant: untested entirely.
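The reset-between-ticks suggestion, sketched against a stub: the real test would use the project's actor fixtures and buffers, so `StubActor` here is a hypothetical stand-in for the contract only.

```python
class StubActor:
    """Hypothetical stand-in for the last_gsp_loss reset contract."""

    def __init__(self):
        self.last_gsp_loss = None
        self.gsp_buffer = []

    def learn(self):
        self.last_gsp_loss = None        # reset at the top of every learn()
        if self.gsp_buffer:              # GSP step fires only with data
            self.last_gsp_loss = 0.5     # stand-in scalar loss


def test_reset_between_ticks():
    actor = StubActor()
    actor.gsp_buffer.append("sample")
    actor.learn()                        # first tick: GSP step fires
    assert isinstance(actor.last_gsp_loss, float)

    actor.gsp_buffer.clear()
    actor.learn()                        # second tick: no GSP step
    assert actor.last_gsp_loss is None   # the load-bearing invariant
```

If the reset at the top of `learn()` were removed, the stale float from the first tick would survive the second call and the final assertion would fail, so the test carries real coverage.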

Nits

  • actor.py:446 uses in {'DDPG'} / in {'RDDPG'} (single-element sets) where == 'DDPG' would read more cleanly. Pre-existing, not introduced here.

…reset test

- learn_TD3 now accepts recurrent=False to match the DDPG/RDDPG signatures;
  the learn_gsp dispatch was passing 3 positional args to a 2-arg method.
  Latent bug today (GSP networks are built as DDPG/attention, not TD3) but
  removes the footgun before the diagnostic batch exercises TD3 variants.
- TD3's non-actor-update step returns (0, 0); previously we unwrapped to
  0.0 and logged it. That produces false collapse signals every
  update_actor_iter-1 ticks. Now we skip the entry entirely — leave
  last_gsp_loss at None as if no GSP step ran.
- Doc the semantic: last_gsp_loss is the GSP learner's training loss,
  which is actor loss (policy-gradient signal) for DDPG/RDDPG/TD3 and
  genuine MSE only for attention. For prediction-collapse detection
  consumers should rely on gsp_squared_error and the HDF5Logger
  episode-level gsp_output_std / gsp_pred_target_corr attrs.
- Add reset-between-ticks test covering the load-bearing invariant that
  last_gsp_loss returns to None when a learn() call runs but no GSP
  learning step fires.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
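The revised tuple handling described in this commit can be sketched as follows; the function name and structure are illustrative, not the actual actor.py code.

```python
def record_gsp_loss(actor, loss):
    """Illustrative: fold learn_gsp's return into last_gsp_loss."""
    if loss is None:
        return  # no GSP step ran; last_gsp_loss stays None from learn()
    if isinstance(loss, tuple):
        # TD3 non-actor-update step returns (0, 0): the critic stepped
        # but the actor did not. Skip entirely instead of logging 0.0,
        # so collapse detectors see None rather than a false positive.
        return
    actor.last_gsp_loss = float(loss)
```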

jdbloom commented Apr 13, 2026

Second-pass review of e7f0a55 — verdict: ready to merge

All four fixes land cleanly. Verified against the tree at e7f0a55.

Per-concern verification

#3 TD3 signature (learning_aids.py:388) — fixed correctly.

  • (a) Signature learn_TD3(self, networks, gsp=False, recurrent=False) now matches learn_DDPG:280 and learn_RDDPG:308 exactly.
  • (b) recurrent is accepted but unused inside the body — same posture as learn_DDPG, which also ignores it. Consistent.
  • (c) Only non-GSP caller is actor.py:447 (self.learn_TD3(self.networks)) which passes one positional arg; defaults absorb the rest. No breakage.

#2 TD3 tuple path (actor.py:468-469) — fixed correctly.

  • (a) Reset happens in learn():421 (self.last_gsp_loss = None) before learn_gsp() is called on line 430. Bare return from the tuple branch leaves the already-None value intact. Correct.
  • (b) learn_gsp has no further mutation of last_gsp_loss after line 469. Clean.
  • (c) Documented semantic ("no GSP step contributed this tick") matches behavior.
  • Note: the outer if loss is not None: check on line 463 is not dead — DDPG/RDDPG/attention still return scalars and fall through to float(loss) on line 470. Only the tuple sub-branch short-circuits.

#1 Docstring (actor.py:116-125) — accurate.

  • learn_DDPG:306 returns actor_loss.item() (policy-gradient loss, line 299-301: -critic(s, π(s)).mean()). Correct.
  • learn_RDDPG:386 returns actor_loss.item() (same pattern, line 380). Correct.
  • learn_TD3:438 returns actor_loss.item() on actor-update steps (line 432: -mean(critic_1(s, π(s)))). Correct.
  • learn_attention:450 returns loss.item() where loss = Loss(pred_headings, labels) — genuine prediction MSE. Correct.
  • The redirect to gsp_squared_error / gsp_output_std / gsp_pred_target_corr for collapse detection is the right guidance.

#4 Reset test — valid probe.

  • mem_ctr is the right guard field: learn_gsp():450-451 early-returns on self.gsp_networks['replay'].mem_ctr < self.gsp_batch_size before any loss assignment.
  • The early-return skips both loss = ... and the if loss is not None block, so last_gsp_loss stays at the None set by learn():421.
  • The test would genuinely fail if the reset invariant broke: if line 421 were removed, the stale float from the first learn() call would persist through the second call's assertion. Real coverage, not a tautology.

New issues introduced by e7f0a55

None. No new typos, no None-deref risk, no dead branches. The three original tests in test_gsp_loss_exposure.py are untouched by the actor.py changes (their code paths still hit self.last_gsp_loss = float(loss) on line 470).

Recommendation

Ready to merge. All review concerns addressed correctly, no regressions introduced.

jdbloom merged commit e63bf21 into feat/rddpg-lstm-fix Apr 13, 2026
jdbloom deleted the feature/hdf5-gsp-diagnostics branch April 13, 2026 00:32
jdbloom pushed a commit that referenced this pull request Apr 13, 2026
