Fix R-GSP-N: LSTM hidden state + vectorized learning + stored state (R2D2)#20
Merged
Conversation
… RDDPG networks
`EnvironmentEncoder.forward` now accepts an optional `(h_0, c_0)` and returns `(output, (h_n, c_n))` instead of discarding the hidden state. This enables temporal memory across timesteps for RDDPG. Both `RDDPGActorNetwork.forward` and `RDDPGCriticNetwork.forward` are updated to pass the hidden state through. All callers in `learn_RDDPG`, `DDPG_choose_action`, and the test files are updated to unpack the new tuple return value.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
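The forward-pass change described above can be sketched roughly as follows. This is a minimal illustration, not the repository's actual module: the layer sizes and constructor arguments here are assumptions.

```python
import torch
import torch.nn as nn

class EnvironmentEncoder(nn.Module):
    """Sketch of a hidden-state-preserving LSTM encoder.
    Dimensions are illustrative, not the repo's real values."""
    def __init__(self, input_dims=8, hidden_size=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_dims, hidden_size, num_layers, batch_first=True)

    def forward(self, obs_seq, hidden=None):
        # Accept a single (seq_len, input) sequence or a (batch, seq_len, input) batch.
        single = obs_seq.dim() == 2
        if single:
            obs_seq = obs_seq.unsqueeze(0)
        # Pass the previous (h_0, c_0) in; hidden=None lets the LSTM zero-initialize.
        output, (h_n, c_n) = self.lstm(obs_seq, hidden)
        if single:
            output = output.squeeze(0)
        # Return the new state instead of discarding it, so the caller
        # can carry memory across timesteps.
        return output, (h_n, c_n)

enc = EnvironmentEncoder()
out, hid = enc(torch.randn(5, 8))          # first call: zero-initialized state
out2, hid2 = enc(torch.randn(5, 8), hid)   # next call: carries memory forward
```

Passing `hidden=None` on the first call lets `nn.LSTM` zero-initialize the state, so existing call sites only need to unpack the extra return value.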
…yle)
Add optional `hidden_size`/`num_layers` params to `SequenceReplayBuffer` so the LSTM `(h, c)` state at the start of each sequence is stored alongside the SARSD data. `sample_buffer` returns a 7-tuple when hidden storage is enabled and a 5-tuple otherwise, preserving full backward compatibility.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
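A minimal sketch of this storage scheme, assuming a NumPy ring buffer and using the method names from the commit message; the real buffer's internals and field layout may differ.

```python
import numpy as np

class SequenceReplayBuffer:
    """Sketch of an R2D2-style stored-state sequence buffer."""
    def __init__(self, max_size, seq_len, obs_dim, act_dim,
                 hidden_size=None, num_layers=1):
        self.max_size, self.seq_len, self.ctr = max_size, seq_len, 0
        self.state = np.zeros((max_size, seq_len, obs_dim), dtype=np.float32)
        self.action = np.zeros((max_size, seq_len, act_dim), dtype=np.float32)
        self.reward = np.zeros((max_size, seq_len), dtype=np.float32)
        self.state_ = np.zeros((max_size, seq_len, obs_dim), dtype=np.float32)
        self.done = np.zeros((max_size, seq_len), dtype=bool)
        self.store_hidden = hidden_size is not None
        if self.store_hidden:
            # (h, c) captured at the *start* of each stored sequence.
            self.h = np.zeros((max_size, num_layers, hidden_size), dtype=np.float32)
            self.c = np.zeros((max_size, num_layers, hidden_size), dtype=np.float32)

    def set_sequence_hidden(self, idx, h, c):
        self.h[idx], self.c[idx] = h, c

    def store(self, s, a, r, s_, d):
        i = self.ctr % self.max_size
        self.state[i], self.action[i], self.reward[i] = s, a, r
        self.state_[i], self.done[i] = s_, d
        self.ctr += 1
        return i

    def sample_buffer(self, batch_size):
        idx = np.random.choice(min(self.ctr, self.max_size), batch_size)
        base = (self.state[idx], self.action[idx], self.reward[idx],
                self.state_[idx], self.done[idx])
        if self.store_hidden:
            return base + (self.h[idx], self.c[idx])  # 7-tuple with hidden state
        return base                                   # 5-tuple, old behavior

buf = SequenceReplayBuffer(100, seq_len=4, obs_dim=3, act_dim=2, hidden_size=16)
i = buf.store(np.zeros((4, 3)), np.zeros((4, 2)), np.zeros(4),
              np.zeros((4, 3)), np.zeros(4, dtype=bool))
buf.set_sequence_hidden(i, np.zeros((1, 16)), np.zeros((1, 16)))
batch = buf.sample_buffer(2)  # 7 values when hidden storage is enabled
```

The backward-compatible path falls out naturally: omitting `hidden_size` skips the `(h, c)` arrays entirely, and `sample_buffer` returns the old 5-tuple.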
Replace the per-sample loop in `learn_RDDPG` with a batched implementation using R2D2-style burn-in. Splits each sequence into a burn-in prefix (first half) and a training suffix (second half), refreshes the LSTM hidden state during burn-in under `no_grad`, then computes the critic/actor loss on the last timestep of the training suffix. Also updates `sample_memory` to handle 7-value returns from `SequenceReplayBuffer` (with hidden states) and passes `hidden_size`/`num_layers` to `SequenceReplayBuffer` in `build_gsp_network`.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
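The burn-in split might look like the following sketch, where a bare `nn.LSTM` and made-up dimensions stand in for the actual actor/critic networks:

```python
import torch
import torch.nn as nn

# Hypothetical shapes: a batch of stored sequences plus their stored start state.
batch, seq_len, obs_dim, hidden_size = 32, 10, 8, 64
lstm = nn.LSTM(obs_dim, hidden_size, batch_first=True)
states = torch.randn(batch, seq_len, obs_dim)
h0 = torch.zeros(1, batch, hidden_size)  # (h, c) as sampled from the buffer
c0 = torch.zeros(1, batch, hidden_size)

# R2D2-style split: the first half only refreshes the stale stored state,
# the second half is what gradients actually flow through.
burn = seq_len // 2
with torch.no_grad():
    _, (h, c) = lstm(states[:, :burn], (h0, c0))

# detach() makes the gradient isolation explicit (no_grad already severed it).
out, _ = lstm(states[:, burn:], (h.detach(), c.detach()))
last = out[:, -1]  # the loss is computed on the final timestep only
```

The single batched forward over `(batch, seq_len, obs_dim)` replaces the per-sample Python loop, which is where the reported speedup comes from.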
…G works on MPS (80ms vs 150ms CPU)
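Device selection for the MPS path can be guarded so the same code still runs on CPU-only machines; the 80 ms vs. 150 ms figures above are quoted from the commit, not re-measured here.

```python
import torch

# Prefer Apple's Metal backend (MPS) when present, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

x = torch.randn(4, 4, device=device)
y = x @ x  # runs on whichever device was selected
```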
Summary
Three-part fix for the R-GSP-N recurrent pipeline:
1. LSTM Hidden State Fix
`EnvironmentEncoder.forward()` now returns `(output, (h_n, c_n))` and accepts `hidden=(h_0, c_0)`. Previously the hidden state was discarded on every call (`lstm_out, _ = ...`), meaning R-GSP-N had no temporal memory beyond the sliding observation window. Supports both single `(seq_len, input)` and batched `(batch, seq_len, input)` inputs.
2. SequenceReplayBuffer Stores Hidden State (R2D2-style)
New `hidden_size` and `num_layers` params. When enabled, the buffer stores `(h, c)` at sequence boundaries via `set_sequence_hidden()`. `sample_buffer()` returns 7 values `(s, a, r, s_, d, h, c)` instead of 5. Backward compatible: buffers without `hidden_size` work as before.
3. Vectorized learn_RDDPG with Burn-in
Replaces per-sample loop (1,500ms) with single batched forward/backward pass (<200ms).
Burn-in runs under `T.no_grad()` for gradient isolation.
Reference: R2D2 (Kapturowski et al., ICLR 2019), stored state + burn-in.
Performance
Tests
Breaking Changes
- `EnvironmentEncoder.forward()` returns a tuple `(output, hidden)` instead of a single tensor
- `RDDPGActorNetwork.forward()` returns `(mu, hidden)` instead of `mu`
- `RDDPGCriticNetwork.forward()` returns `(q_value, hidden)` instead of `q_value`
🤖 Generated with Claude Code