From b634cef7aa8ac6970a3e8dcc3aca23c66283b2f3 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 20 Apr 2026 14:10:59 +0200 Subject: [PATCH 1/9] extending RB interface to emit weights and indices as well, allowing to implement PER buffer Signed-off-by: Thorsten Kurth --- .../include/internal/rl/off_policy/ddpg.h | 12 +++---- src/csrc/include/internal/rl/off_policy/sac.h | 15 ++++---- src/csrc/include/internal/rl/off_policy/td3.h | 16 ++++----- src/csrc/include/internal/rl/replay_buffer.h | 17 +++++++--- src/csrc/rl/off_policy/ddpg.cpp | 13 ++++--- src/csrc/rl/off_policy/sac.cpp | 14 +++++--- src/csrc/rl/off_policy/td3.cpp | 13 ++++--- tests/rl/test_replay_buffer.cpp | 34 ++++++++++++------- 8 files changed, 82 insertions(+), 52 deletions(-) diff --git a/src/csrc/include/internal/rl/off_policy/ddpg.h b/src/csrc/include/internal/rl/off_policy/ddpg.h index 4c0190a4..05b93a8c 100644 --- a/src/csrc/include/internal/rl/off_policy/ddpg.h +++ b/src/csrc/include/internal/rl/off_policy/ddpg.h @@ -52,7 +52,8 @@ template void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const ModelPack& q_model, const ModelPack& q_model_target, torch::Tensor state_old_tensor, torch::Tensor state_new_tensor, torch::Tensor action_old_tensor, torch::Tensor action_new_tensor, torch::Tensor reward_tensor, - torch::Tensor d_tensor, const T& gamma, const T& rho, T& p_loss_val, T& q_loss_val) { + torch::Tensor d_tensor, torch::Tensor is_weights, const T& gamma, const T& rho, + torch::Tensor& td_errors, T& p_loss_val, T& q_loss_val) { // nvtx marker torchfort::nvtx::rangePush("torchfort_train_ddpg"); @@ -72,10 +73,6 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const // value functions q_model.model->train(); - // opt - // loss is fixed by algorithm - auto q_loss_func = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); - // policy function // compute y: use the target models for q_new, no grads torch::Tensor y_tensor; 
@@ -87,10 +84,11 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const } // backward and update step - // compute loss + // IS-weighted MSE loss: mean(w * (q - y)^2) torch::Tensor q_old_tensor = torch::squeeze(q_model.model->forward(std::vector{state_old_tensor, action_old_tensor})[0], 1); - torch::Tensor q_loss_tensor = q_loss_func->forward(q_old_tensor, y_tensor); + td_errors = torch::abs(q_old_tensor - y_tensor).detach(); + torch::Tensor q_loss_tensor = torch::mean(is_weights * torch::square(q_old_tensor - y_tensor)); auto state = q_model.state; if (state->step_train_current % q_model.grad_accumulation_steps == 0) { diff --git a/src/csrc/include/internal/rl/off_policy/sac.h b/src/csrc/include/internal/rl/off_policy/sac.h index 677059f2..e0b352a0 100644 --- a/src/csrc/include/internal/rl/off_policy/sac.h +++ b/src/csrc/include/internal/rl/off_policy/sac.h @@ -56,10 +56,10 @@ template void train_sac(const PolicyPack& p_model, const std::vector& q_models, const std::vector& q_models_target, torch::Tensor state_old_tensor, torch::Tensor state_new_tensor, torch::Tensor action_old_tensor, torch::Tensor reward_tensor, - torch::Tensor d_tensor, const std::shared_ptr& alpha_model, + torch::Tensor d_tensor, torch::Tensor is_weights, const std::shared_ptr& alpha_model, const std::shared_ptr& alpha_optimizer, const std::shared_ptr& alpha_lr_scheduler, const T& target_entropy, const T& gamma, - const T& rho, T& p_loss_val, T& q_loss_val) { + const T& rho, torch::Tensor& td_errors, T& p_loss_val, T& q_loss_val) { // nvtx marker torchfort::nvtx::rangePush("torchfort_train_sac"); @@ -84,10 +84,6 @@ void train_sac(const PolicyPack& p_model, const std::vector& q_models q_model_target.model->train(); } - // opt - // loss is fixed by algorithm - auto q_loss_func = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); - // if we are updating the entropy coefficient, do that first torch::Tensor alpha_loss; auto state = p_model.state; 
@@ -168,9 +164,12 @@ void train_sac(const PolicyPack& p_model, const std::vector& q_models } // backward and update step + // IS-weighted MSE loss: mean(w * (q - y)^2), summed across critics + // td_errors taken from first critic only torch::Tensor q_old_tensor = torch::squeeze(q_models[0].model->forward(std::vector{state_old_tensor, action_old_tensor})[0], 1); - torch::Tensor q_loss_tensor = q_loss_func->forward(q_old_tensor, y_tensor); + td_errors = torch::abs(q_old_tensor - y_tensor).detach(); + torch::Tensor q_loss_tensor = torch::mean(is_weights * torch::square(q_old_tensor - y_tensor)); state = q_models[0].state; if (state->step_train_current % q_models[0].grad_accumulation_steps == 0) { q_models[0].optimizer->zero_grad(); @@ -179,7 +178,7 @@ void train_sac(const PolicyPack& p_model, const std::vector& q_models // compute loss q_old_tensor = torch::squeeze( q_models[i].model->forward(std::vector{state_old_tensor, action_old_tensor})[0], 1); - q_loss_tensor = q_loss_tensor + q_loss_func->forward(q_old_tensor, y_tensor); + q_loss_tensor = q_loss_tensor + torch::mean(is_weights * torch::square(q_old_tensor - y_tensor)); state = q_models[i].state; if (state->step_train_current % q_models[i].grad_accumulation_steps == 0) { q_models[i].optimizer->zero_grad(); diff --git a/src/csrc/include/internal/rl/off_policy/td3.h b/src/csrc/include/internal/rl/off_policy/td3.h index 1d267028..1af8a6e4 100644 --- a/src/csrc/include/internal/rl/off_policy/td3.h +++ b/src/csrc/include/internal/rl/off_policy/td3.h @@ -52,8 +52,8 @@ template void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const std::vector& q_models, const std::vector& q_models_target, torch::Tensor state_old_tensor, torch::Tensor state_new_tensor, torch::Tensor action_old_tensor, torch::Tensor action_new_tensor, - torch::Tensor reward_tensor, torch::Tensor d_tensor, const T& gamma, const T& rho, T& p_loss_val, - T& q_loss_val, bool update_policy) { + torch::Tensor reward_tensor, 
torch::Tensor d_tensor, torch::Tensor is_weights, const T& gamma, + const T& rho, torch::Tensor& td_errors, T& p_loss_val, T& q_loss_val, bool update_policy) { // nvtx marker torchfort::nvtx::rangePush("torchfort_train_td3"); @@ -76,10 +76,6 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const q_model.model->train(); } - // opt - // loss is fixed by algorithm - auto q_loss_func = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); - // policy function // compute y: use the target models for q_new, no grads torch::Tensor y_tensor; @@ -96,10 +92,12 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const } // backward and update step - // compute loss for critics and zero grads while we are at it + // IS-weighted MSE loss: mean(w * (q - y)^2), summed across critics + // td_errors taken from first critic only (consistent with policy update using q_models[0]) torch::Tensor q_old_tensor = torch::squeeze(q_models[0].model->forward(std::vector{state_old_tensor, action_old_tensor})[0], 1); - torch::Tensor q_loss_tensor = q_loss_func->forward(q_old_tensor, y_tensor); + td_errors = torch::abs(q_old_tensor - y_tensor).detach(); + torch::Tensor q_loss_tensor = torch::mean(is_weights * torch::square(q_old_tensor - y_tensor)); auto state = q_models[0].state; if (state->step_train_current % q_models[0].grad_accumulation_steps == 0) { q_models[0].optimizer->zero_grad(); @@ -107,7 +105,7 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const for (int i = 1; i < q_models.size(); ++i) { q_old_tensor = torch::squeeze( q_models[i].model->forward(std::vector{state_old_tensor, action_old_tensor})[0], 1); - q_loss_tensor = q_loss_tensor + q_loss_func->forward(q_old_tensor, y_tensor); + q_loss_tensor = q_loss_tensor + torch::mean(is_weights * torch::square(q_old_tensor - y_tensor)); state = q_models[i].state; if (state->step_train_current % q_models[i].grad_accumulation_steps == 0) { 
q_models[i].optimizer->zero_grad(); diff --git a/src/csrc/include/internal/rl/replay_buffer.h b/src/csrc/include/internal/rl/replay_buffer.h index fc7c3883..37e7d238 100644 --- a/src/csrc/include/internal/rl/replay_buffer.h +++ b/src/csrc/include/internal/rl/replay_buffer.h @@ -30,6 +30,8 @@ namespace torchfort { namespace rl { using BufferEntry = std::tuple; +using SampleResult = std::tuple; enum RewardReductionMode { Sum = 1, Mean = 2, WeightedMean = 3, SumNoSkip = 4, MeanNoSkip = 5, WeightedMeanNoSkip = 6 }; @@ -55,8 +57,11 @@ class ReplayBuffer { // virtual functions virtual void update(torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor) = 0; - // sample element randomly - virtual std::tuple sample(int) = 0; + // sample elements: returns (s, a, s', r, d, is_weights, indices) + // is_weights are all-ones and indices are all-zeros for uniform sampling + virtual SampleResult sample(int) = 0; + // update priorities based on per-sample TD errors; no-op for uniform sampling + virtual void update_priorities(torch::Tensor /* indices */, torch::Tensor /* td_errors */) {} // get specific element virtual BufferEntry get(int) = 0; // helper functions @@ -134,7 +139,7 @@ class UniformReplayBuffer : public ReplayBuffer, public std::enable_shared_from_ } } - std::tuple sample(int batch_size) { + SampleResult sample(int batch_size) { // add no grad guard torch::NoGradGuard no_grad; @@ -227,7 +232,11 @@ class UniformReplayBuffer : public ReplayBuffer, public std::enable_shared_from_ rtens = torch::stack(rtens_list, 0).clone(); dtens = torch::stack(dtens_list, 0).clone(); - return std::make_tuple(stens, atens, sptens, rtens, dtens); + // uniform buffer: unit IS weights, zero indices (ignored by no-op update_priorities) + auto is_weights = torch::ones(batch_size, torch::TensorOptions().dtype(torch::kFloat32).device(device_)); + auto indices = torch::zeros(batch_size, torch::TensorOptions().dtype(torch::kLong).device(device_)); + + return 
std::make_tuple(stens, atens, sptens, rtens, dtens, is_weights, indices); } BufferEntry get(int index) { diff --git a/src/csrc/rl/off_policy/ddpg.cpp b/src/csrc/rl/off_policy/ddpg.cpp index c6aeb9ca..4f809474 100644 --- a/src/csrc/rl/off_policy/ddpg.cpp +++ b/src/csrc/rl/off_policy/ddpg.cpp @@ -533,12 +533,12 @@ void DDPGSystem::trainStep(float& p_loss_val, float& q_loss_val) { train_step_count_++; // get samples from replay buffer - torch::Tensor s, a, ap, sp, r, d; + torch::Tensor s, a, ap, sp, r, d, is_weights, sample_indices; { torch::NoGradGuard no_grad; // get a sample from the replay buffer - std::tie(s, a, sp, r, d) = replay_buffer_->sample(batch_size_); + std::tie(s, a, sp, r, d, is_weights, sample_indices) = replay_buffer_->sample(batch_size_); // upload to device s = s.to(model_device_); @@ -546,6 +546,7 @@ void DDPGSystem::trainStep(float& p_loss_val, float& q_loss_val) { sp = sp.to(model_device_); r = r.to(model_device_); d = d.to(model_device_); + is_weights = is_weights.to(model_device_); // sync and apply state normalization if (state_normalizer_) { @@ -565,8 +566,12 @@ void DDPGSystem::trainStep(float& p_loss_val, float& q_loss_val) { } // train step - train_ddpg(p_model_, p_model_target_, q_model_, q_model_target_, s, sp, a, ap, r, d, - static_cast(std::pow(gamma_, nstep_)), rho_, p_loss_val, q_loss_val); + torch::Tensor td_errors; + train_ddpg(p_model_, p_model_target_, q_model_, q_model_target_, s, sp, a, ap, r, d, is_weights, + static_cast(std::pow(gamma_, nstep_)), rho_, td_errors, p_loss_val, q_loss_val); + + // update priorities (no-op for uniform buffer) + replay_buffer_->update_priorities(sample_indices, td_errors); } } // namespace off_policy diff --git a/src/csrc/rl/off_policy/sac.cpp b/src/csrc/rl/off_policy/sac.cpp index 99e7ac19..3654357c 100644 --- a/src/csrc/rl/off_policy/sac.cpp +++ b/src/csrc/rl/off_policy/sac.cpp @@ -700,12 +700,12 @@ void SACSystem::trainStep(float& p_loss_val, float& q_loss_val) { train_step_count_++; // we 
need these - torch::Tensor s, a, ap, sp, r, d; + torch::Tensor s, a, ap, sp, r, d, is_weights, sample_indices; { torch::NoGradGuard no_grad; // get a sample from the replay buffer - std::tie(s, a, sp, r, d) = replay_buffer_->sample(batch_size_); + std::tie(s, a, sp, r, d, is_weights, sample_indices) = replay_buffer_->sample(batch_size_); // upload to device s = s.to(model_device_); @@ -713,6 +713,7 @@ void SACSystem::trainStep(float& p_loss_val, float& q_loss_val) { sp = sp.to(model_device_); r = r.to(model_device_); d = d.to(model_device_); + is_weights = is_weights.to(model_device_); // sync and apply state normalization if (state_normalizer_) { @@ -729,8 +730,13 @@ void SACSystem::trainStep(float& p_loss_val, float& q_loss_val) { } // train step - train_sac(p_model_, q_models_, q_models_target_, s, sp, a, r, d, alpha_model_, alpha_optimizer_, alpha_lr_scheduler_, - target_entropy_, static_cast(std::pow(gamma_, nstep_)), rho_, p_loss_val, q_loss_val); + torch::Tensor td_errors; + train_sac(p_model_, q_models_, q_models_target_, s, sp, a, r, d, is_weights, alpha_model_, alpha_optimizer_, + alpha_lr_scheduler_, target_entropy_, static_cast(std::pow(gamma_, nstep_)), rho_, td_errors, + p_loss_val, q_loss_val); + + // update priorities (no-op for uniform buffer) + replay_buffer_->update_priorities(sample_indices, td_errors); } } // namespace off_policy diff --git a/src/csrc/rl/off_policy/td3.cpp b/src/csrc/rl/off_policy/td3.cpp index 8ad9bee9..09957054 100644 --- a/src/csrc/rl/off_policy/td3.cpp +++ b/src/csrc/rl/off_policy/td3.cpp @@ -607,12 +607,12 @@ void TD3System::trainStep(float& p_loss_val, float& q_loss_val) { // update policy? 
bool update_policy = (train_step_count_ % policy_lag_ == 0); - torch::Tensor s, a, ap, sp, r, d; + torch::Tensor s, a, ap, sp, r, d, is_weights, sample_indices; { torch::NoGradGuard no_grad; // get a sample from the replay buffer - std::tie(s, a, sp, r, d) = replay_buffer_->sample(batch_size_); + std::tie(s, a, sp, r, d, is_weights, sample_indices) = replay_buffer_->sample(batch_size_); // upload to device s = s.to(model_device_); @@ -620,6 +620,7 @@ void TD3System::trainStep(float& p_loss_val, float& q_loss_val) { sp = sp.to(model_device_); r = r.to(model_device_); d = d.to(model_device_); + is_weights = is_weights.to(model_device_); // sync and apply state normalization if (state_normalizer_) { @@ -639,8 +640,12 @@ void TD3System::trainStep(float& p_loss_val, float& q_loss_val) { } // train step - train_td3(p_model_, p_model_target_, q_models_, q_models_target_, s, sp, a, ap, r, d, - static_cast(std::pow(gamma_, nstep_)), rho_, p_loss_val, q_loss_val, update_policy); + torch::Tensor td_errors; + train_td3(p_model_, p_model_target_, q_models_, q_models_target_, s, sp, a, ap, r, d, is_weights, + static_cast(std::pow(gamma_, nstep_)), rho_, td_errors, p_loss_val, q_loss_val, update_policy); + + // update priorities (no-op for uniform buffer) + replay_buffer_->update_priorities(sample_indices, td_errors); } } // namespace off_policy diff --git a/tests/rl/test_replay_buffer.cpp b/tests/rl/test_replay_buffer.cpp index 2e60e470..a2e94d9f 100644 --- a/tests/rl/test_replay_buffer.cpp +++ b/tests/rl/test_replay_buffer.cpp @@ -103,8 +103,8 @@ TEST_P(ReplayBuffer, ShapeConsistency) { auto rbuff = getTestReplayBuffer(buffer_size, n_envs, 0.95, 1); // sample - torch::Tensor stens, atens, sptens, rtens, dtens; - std::tie(stens, atens, sptens, rtens, dtens) = rbuff->sample(batch_size); + torch::Tensor stens, atens, sptens, rtens, dtens, is_weights, sample_indices; + std::tie(stens, atens, sptens, rtens, dtens, is_weights, sample_indices) = rbuff->sample(batch_size); // check 
shapes EXPECT_EQ(stens.dim(), 2); @@ -112,15 +112,25 @@ TEST_P(ReplayBuffer, ShapeConsistency) { EXPECT_EQ(sptens.dim(), 2); EXPECT_EQ(rtens.dim(), 1); EXPECT_EQ(dtens.dim(), 1); + EXPECT_EQ(is_weights.dim(), 1); + EXPECT_EQ(sample_indices.dim(), 1); EXPECT_EQ(stens.size(0), batch_size); EXPECT_EQ(atens.size(0), batch_size); EXPECT_EQ(sptens.size(0), batch_size); EXPECT_EQ(rtens.size(0), batch_size); EXPECT_EQ(dtens.size(0), batch_size); + EXPECT_EQ(is_weights.size(0), batch_size); + EXPECT_EQ(sample_indices.size(0), batch_size); EXPECT_EQ(stens.size(1), 1); EXPECT_EQ(atens.size(1), 1); + + // uniform buffer: weights must be all-ones, dtype float; indices dtype long + EXPECT_EQ(is_weights.scalar_type(), torch::kFloat32); + EXPECT_EQ(sample_indices.scalar_type(), torch::kLong); + EXPECT_FLOAT_EQ(is_weights.min().item(), 1.f); + EXPECT_FLOAT_EQ(is_weights.max().item(), 1.f); } // check if entries are consistent @@ -138,11 +148,11 @@ TEST_P(ReplayBuffer, EntryConsistency) { auto rbuff = getTestReplayBuffer(buffer_size, n_envs, 0.95, 1); // sample - torch::Tensor stens, atens, sptens, rtens, dtens; + torch::Tensor stens, atens, sptens, rtens, dtens, is_weights, sample_indices; float state_diff = 0; float reward_diff = 0.; for (unsigned int i = 0; i < 4; ++i) { - std::tie(stens, atens, sptens, rtens, dtens) = rbuff->sample(batch_size); + std::tie(stens, atens, sptens, rtens, dtens, is_weights, sample_indices) = rbuff->sample(batch_size); // compute differences: state_diff += torch::sum(torch::abs(stens + atens - sptens)).item(); @@ -202,10 +212,10 @@ TEST_P(ReplayBuffer, NStepConsistency) { auto rbuff = getTestReplayBuffer(buffer_size, n_envs, gamma, nstep); // sample a batch - torch::Tensor stens, atens, sptens, rtens, dtens; + torch::Tensor stens, atens, sptens, rtens, dtens, is_weights, sample_indices; float state_diff = 0; float reward_diff = 0.; - std::tie(stens, atens, sptens, rtens, dtens) = rbuff->sample(batch_size); + std::tie(stens, atens, sptens, rtens, 
dtens, is_weights, sample_indices) = rbuff->sample(batch_size); // iterate over samples in batch torch::Tensor stemp, atemp, sptemp, sstens, rtemp, dtemp, spfin; @@ -339,8 +349,8 @@ TEST(RewardNormalization, UnitStdPreservedMean) { // Sample a batch and normalize rewards as trainStep does const int batch_size = 256; - torch::Tensor s, a, sp, r, d; - std::tie(s, a, sp, r, d) = rbuff->sample(batch_size); + torch::Tensor s, a, sp, r, d, is_weights, sample_indices; + std::tie(s, a, sp, r, d, is_weights, sample_indices) = rbuff->sample(batch_size); // mirror the system's trainStep normalization auto r_norm = reward_normalizer.normalize(r.unsqueeze(1)).squeeze(1); @@ -383,8 +393,8 @@ TEST(RewardNormalization, SignPreservation) { state = next_state; } - torch::Tensor s, a, sp, r, d; - std::tie(s, a, sp, r, d) = rbuff->sample(buffer_size); + torch::Tensor s, a, sp, r, d, is_weights, sample_indices; + std::tie(s, a, sp, r, d, is_weights, sample_indices) = rbuff->sample(buffer_size); auto r_norm = reward_normalizer.normalize(r.unsqueeze(1)).squeeze(1); // all normalized rewards must remain positive @@ -421,8 +431,8 @@ TEST(RewardNormalization, LargeScaleNormalizedToUnitStd) { state = next_state; } - torch::Tensor s, a, sp, r, d; - std::tie(s, a, sp, r, d) = rbuff->sample(buffer_size); + torch::Tensor s, a, sp, r, d, is_weights, sample_indices; + std::tie(s, a, sp, r, d, is_weights, sample_indices) = rbuff->sample(buffer_size); auto r_norm = reward_normalizer.normalize(r.unsqueeze(1)).squeeze(1); // std should be close to 1 regardless of the original reward scale From 67baf8e572c839944d753375527fe81f3faca957 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 20 Apr 2026 16:01:31 +0200 Subject: [PATCH 2/9] implemented PER including tests and interface helpers Signed-off-by: Thorsten Kurth --- CMakeLists.txt | 1 + src/csrc/include/internal/rl/replay_buffer.h | 378 +++++++++++++++++++ src/csrc/include/internal/rl/setup.h | 39 ++ src/csrc/rl/off_policy/ddpg.cpp | 23 +-
src/csrc/rl/off_policy/sac.cpp | 23 +- src/csrc/rl/off_policy/td3.cpp | 23 +- src/csrc/rl/setup.cpp | 65 ++++ tests/rl/test_replay_buffer.cpp | 238 ++++++++++++ 8 files changed, 730 insertions(+), 60 deletions(-) create mode 100644 src/csrc/include/internal/rl/setup.h create mode 100644 src/csrc/rl/setup.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f6ed1baa..3ce29b5e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -165,6 +165,7 @@ target_sources(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/rl/sac_model.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/policy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/running_normalizer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/setup.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/off_policy/interface.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/off_policy/ddpg.cpp diff --git a/src/csrc/include/internal/rl/replay_buffer.h b/src/csrc/include/internal/rl/replay_buffer.h index 37e7d238..ed82a27d 100644 --- a/src/csrc/include/internal/rl/replay_buffer.h +++ b/src/csrc/include/internal/rl/replay_buffer.h @@ -353,5 +353,383 @@ class UniformReplayBuffer : public ReplayBuffer, public std::enable_shared_from_ bool skip_incomplete_steps_; }; +// Prioritized Experience Replay buffer (Schaul et al., 2015). +// +// Priorities are stored in a sum-tree so that O(log N) sampling and updates +// are possible. Each physical buffer slot holds one timestep for all n_envs_ +// environments. Priorities operate at the step level; the env is chosen +// uniformly given the step. +// +// IS weight formula (normalised to [0,1]): +// w_i = (min_p_alpha / p_i)^beta +// where p_i = p_raw^alpha is the value stored in the tree leaf and +// min_p_alpha is the minimum non-zero p^alpha currently tracked in the buffer. +// max weight = 1 by construction (achieved by the lowest-priority entry). +// +// When a new entry is written at position write_pos_: +// 1. 
Set its tree priority to 0 (not yet a valid n-step starting point). +// 2. Enable position (write_pos_ - nstep_ + 1) % max_size_ with max_priority +// (it now has nstep_ consecutive entries ahead of it). +// This keeps the circular-buffer n-step invariant correct without any +// extra bookkeeping. +class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_from_this { + +public: + PrioritizedReplayBuffer(size_t max_size, size_t min_size, size_t n_envs, float gamma, int nstep, + RewardReductionMode reward_reduction_mode, float alpha, float beta0, float beta_max, + size_t beta_steps, int device) + : ReplayBuffer(max_size, min_size, n_envs, device), gamma_(gamma), nstep_(nstep), alpha_(alpha), beta_(beta0), + beta_max_(beta_max), + beta_increment_(beta_steps > 0 ? (beta_max - beta0) / static_cast(beta_steps) : 0.f), + epsilon_(1e-6f), max_priority_(1.f), min_p_alpha_(std::numeric_limits::max()), + write_pos_(0), current_size_(0) { + + skip_incomplete_steps_ = true; + if (reward_reduction_mode == RewardReductionMode::MeanNoSkip) { + reward_reduction_mode_ = RewardReductionMode::Mean; + skip_incomplete_steps_ = false; + } else if (reward_reduction_mode == RewardReductionMode::WeightedMeanNoSkip) { + reward_reduction_mode_ = RewardReductionMode::WeightedMean; + skip_incomplete_steps_ = false; + } else if (reward_reduction_mode == RewardReductionMode::SumNoSkip) { + reward_reduction_mode_ = RewardReductionMode::Sum; + skip_incomplete_steps_ = false; + } else { + reward_reduction_mode_ = reward_reduction_mode; + } + + buffer_.resize(max_size_); + // 1-indexed sum-tree: root at 1, leaves at [max_size_, 2*max_size_ - 1] + sum_tree_.assign(2 * max_size_, 0.f); + priorities_.assign(max_size_, 0.f); + } + + PrioritizedReplayBuffer(const PrioritizedReplayBuffer&) = delete; + + void update(torch::Tensor s, torch::Tensor a, torch::Tensor sp, torch::Tensor r, torch::Tensor d) override { + torch::NoGradGuard no_grad; + + if ((s.sizes()[0] != n_envs_) || 
(a.sizes()[0] != n_envs_) || (sp.sizes()[0] != n_envs_)) { + throw std::runtime_error( + "PrioritizedReplayBuffer::update: leading dimension of s, a, sp must equal n_envs"); + } + if ((r.sizes()[0] != n_envs_) || (d.sizes()[0] != n_envs_)) { + throw std::runtime_error( + "PrioritizedReplayBuffer::update: leading dimension of r, d must equal n_envs"); + } + + auto sc = s.to(device_, s.dtype(), false, true); + auto ac = a.to(device_, a.dtype(), false, true); + auto spc = sp.to(device_, sp.dtype(), false, true); + auto rc = r.to(device_, r.dtype(), false, true); + auto dc = d.to(device_, d.dtype(), false, true); + + // write entry; zero priority until it becomes a valid n-step starting point + size_t pos = write_pos_; + buffer_[pos] = std::make_tuple(sc, ac, spc, rc, dc); + treeUpdate_(pos, 0.f); + priorities_[pos] = 0.f; + + write_pos_ = (write_pos_ + 1) % max_size_; + current_size_ = std::min(current_size_ + 1, max_size_); + + // the position that can now start a complete n-step rollout ending at pos + if (current_size_ >= static_cast(nstep_)) { + size_t valid_pos = (pos + max_size_ - nstep_ + 1) % max_size_; + float p_alpha = std::pow(max_priority_, alpha_); + treeUpdate_(valid_pos, p_alpha); + priorities_[valid_pos] = max_priority_; + min_p_alpha_ = std::min(min_p_alpha_, p_alpha); + } + } + + SampleResult sample(int batch_size) override { + torch::NoGradGuard no_grad; + + auto stens_list = std::vector(batch_size); + auto atens_list = std::vector(batch_size); + auto sptens_list = std::vector(batch_size); + auto rtens_list = std::vector(batch_size); + auto dtens_list = std::vector(batch_size); + auto wtens_list = std::vector(batch_size); + auto itens_list = std::vector(batch_size); + + // anneal beta towards beta_max_ + beta_ = std::min(beta_max_, beta_ + beta_increment_); + + float total = treeTotal_(); + float segment = total / static_cast(batch_size); + // minimum p^alpha currently in the buffer — normalises IS weights so max weight = 1 + float min_p_alpha = 
min_p_alpha_; + + std::uniform_int_distribution env_dist(0, n_envs_ - 1); + + int s = 0; + while (s < batch_size) { + // stratified sampling: one draw per equal-width segment of the priority sum + float lo = segment * static_cast(s); + float hi = segment * static_cast(s + 1); + std::uniform_real_distribution seg_dist(lo, hi); + size_t pos = treeSample_(seg_dist(rng_)); + + // guard against numerical edge-cases (leaf with zero priority) + float p_i = sum_tree_[max_size_ + pos]; + if (p_i <= 0.f) { + continue; + } + + int64_t env_idx = static_cast(env_dist(rng_)); + + // IS weight: w_i = (min_p_alpha / p_i)^beta — max weight = 1 by construction + float weight = std::pow(min_p_alpha / p_i, beta_); + + // extract (state, action, ...) at (pos, env_idx) + torch::Tensor stens, atens, sptens, rtens, dtens; + std::tie(stens, atens, sptens, rtens, dtens) = buffer_[pos]; + stens_list[s] = stens.index({env_idx, "..."}).clone(); + atens_list[s] = atens.index({env_idx, "..."}).clone(); + sptens_list[s] = sptens.index({env_idx, "..."}).clone(); + rtens_list[s] = rtens.index({env_idx}).clone(); + dtens_list[s] = dtens.index({env_idx}).clone(); + + // n-step rollout — identical logic to UniformReplayBuffer, but uses modular indexing + float r_norm = 1.f; + int r_count = 1; + bool skip = false; + torch::Tensor deff = 1.f - dtens_list[s]; + for (int off = 1; off < nstep_; ++off) { + size_t next_pos = (pos + static_cast(off)) % max_size_; + std::tie(stens, atens, sptens, rtens, dtens) = buffer_[next_pos]; + sptens_list[s] = sptens.index({env_idx, "..."}).clone(); + float gamma_eff = static_cast(std::pow(gamma_, off)); + rtens_list[s] = rtens_list[s] + gamma_eff * rtens.index({env_idx}); + r_norm += gamma_eff; + r_count++; + + float d_val = dtens.index({env_idx}).item(); + if (std::abs(d_val - 1.f) < 1e-6f) { + deff = deff * 0.f; + if ((off != nstep_ - 1) && skip_incomplete_steps_) { + skip = true; + } + } + } + dtens_list[s] = (1.f - deff).clone(); + + if (skip) { + continue; // retry 
slot s without incrementing + } + + switch (reward_reduction_mode_) { + case RewardReductionMode::Mean: + rtens_list[s] = rtens_list[s] / static_cast(r_count); + break; + case RewardReductionMode::WeightedMean: + rtens_list[s] = rtens_list[s] / r_norm; + break; + default: + break; + } + + auto float_opts = torch::TensorOptions().dtype(torch::kFloat32).device(device_); + auto long_opts = torch::TensorOptions().dtype(torch::kLong).device(device_); + wtens_list[s] = torch::tensor(weight, float_opts); + itens_list[s] = torch::tensor(static_cast(pos), long_opts); + + ++s; + } + + return std::make_tuple(torch::stack(stens_list, 0).clone(), torch::stack(atens_list, 0).clone(), + torch::stack(sptens_list, 0).clone(), torch::stack(rtens_list, 0).clone(), + torch::stack(dtens_list, 0).clone(), torch::stack(wtens_list, 0).clone(), + torch::stack(itens_list, 0).clone()); + } + + void update_priorities(torch::Tensor indices, torch::Tensor td_errors) override { + torch::NoGradGuard no_grad; + + auto idx_acc = indices.to(torch::kCPU).accessor(); + auto td_acc = td_errors.to(torch::kCPU).accessor(); + + for (int64_t i = 0; i < idx_acc.size(0); ++i) { + size_t pos = static_cast(idx_acc[i]); + float raw_p = std::abs(td_acc[i]) + epsilon_; + float p_alpha = std::pow(raw_p, alpha_); + treeUpdate_(pos, p_alpha); + priorities_[pos] = raw_p; + max_priority_ = std::max(max_priority_, raw_p); + min_p_alpha_ = std::min(min_p_alpha_, p_alpha); + } + } + + BufferEntry get(int index) override { + if (index < 0 || static_cast(index) >= current_size_) { + throw std::runtime_error("PrioritizedReplayBuffer::get: index " + std::to_string(index) + + " out of bounds [0, " + std::to_string(current_size_) + ")."); + } + return buffer_[index]; + } + + bool isReady() const override { return current_size_ >= min_size_; } + + void reset() override { + current_size_ = 0; + write_pos_ = 0; + max_priority_ = 1.f; + min_p_alpha_ = std::numeric_limits::max(); + std::fill(sum_tree_.begin(), sum_tree_.end(), 
0.f); + std::fill(priorities_.begin(), priorities_.end(), 0.f); + } + + size_t getSize() const override { return current_size_; } + + void setSeed(unsigned int seed) override { rng_.seed(seed); } + + torch::Device device() const override { return device_; } + + void printInfo() const override { + std::cout << "prioritized replay buffer parameters:" << std::endl; + std::cout << "max_size = " << max_size_ << std::endl; + std::cout << "min_size = " << min_size_ << std::endl; + std::cout << "n_envs = " << n_envs_ << std::endl; + std::cout << "alpha = " << alpha_ << std::endl; + std::cout << "beta (current) = " << beta_ << std::endl; + std::cout << "beta_max = " << beta_max_ << std::endl; + std::cout << "epsilon = " << epsilon_ << std::endl; + } + + void save(const std::string& fname) const override { + std::vector s_data, a_data, sp_data, r_data, d_data; + for (size_t i = 0; i < current_size_; ++i) { + torch::Tensor s, a, sp, r, d; + std::tie(s, a, sp, r, d) = buffer_[i]; + s_data.push_back(s.to(torch::kCPU)); + a_data.push_back(a.to(torch::kCPU)); + sp_data.push_back(sp.to(torch::kCPU)); + r_data.push_back(r.to(torch::kCPU)); + d_data.push_back(d.to(torch::kCPU)); + } + + std::filesystem::path root_dir(fname); + if (!std::filesystem::exists(root_dir)) { + bool rv = std::filesystem::create_directory(root_dir); + if (!rv) { + throw std::runtime_error("PrioritizedReplayBuffer::save: unable to create directory " + + root_dir.native() + "."); + } + } + + torch::save(s_data, root_dir / "s_data.pt"); + torch::save(a_data, root_dir / "a_data.pt"); + torch::save(sp_data, root_dir / "sp_data.pt"); + torch::save(r_data, root_dir / "r_data.pt"); + torch::save(d_data, root_dir / "d_data.pt"); + + // save raw priorities (before alpha) for positions 0..current_size_-1 + auto prio_tensor = torch::tensor( + std::vector(priorities_.begin(), priorities_.begin() + current_size_)); + torch::save({prio_tensor}, root_dir / "priorities.pt"); + + // save scalar state: [write_pos, 
current_size, beta, max_priority, min_p_alpha] + auto state_tensor = torch::tensor(std::vector{static_cast(write_pos_), + static_cast(current_size_), beta_, + max_priority_, min_p_alpha_}); + torch::save({state_tensor}, root_dir / "per_state.pt"); + } + + void load(const std::string& fname) override { + std::filesystem::path root_dir(fname); + + std::vector s_data, a_data, sp_data, r_data, d_data; + torch::load(s_data, root_dir / "s_data.pt"); + torch::load(a_data, root_dir / "a_data.pt"); + torch::load(sp_data, root_dir / "sp_data.pt"); + torch::load(r_data, root_dir / "r_data.pt"); + torch::load(d_data, root_dir / "d_data.pt"); + + // restore scalar state + std::vector state_vec; + torch::load(state_vec, root_dir / "per_state.pt"); + write_pos_ = static_cast(state_vec[0][0].item()); + current_size_ = static_cast(state_vec[0][1].item()); + beta_ = state_vec[0][2].item(); + max_priority_ = state_vec[0][3].item(); + min_p_alpha_ = state_vec[0][4].item(); + + // restore buffer entries at their physical positions + buffer_.assign(max_size_, BufferEntry{}); + for (size_t i = 0; i < s_data.size(); ++i) { + buffer_[i] = std::make_tuple(s_data[i], a_data[i], sp_data[i], r_data[i], d_data[i]); + } + + // restore priorities and rebuild sum-tree + std::vector prio_vec; + torch::load(prio_vec, root_dir / "priorities.pt"); + std::fill(priorities_.begin(), priorities_.end(), 0.f); + std::fill(sum_tree_.begin(), sum_tree_.end(), 0.f); + for (size_t i = 0; i < current_size_; ++i) { + float raw_p = prio_vec[0][static_cast(i)].item(); + priorities_[i] = raw_p; + if (raw_p > 0.f) { + treeUpdate_(i, std::pow(raw_p, alpha_)); + } + } + } + +private: + // update one leaf and propagate up to root — O(log N) + void treeUpdate_(size_t pos, float priority_alpha) { + size_t idx = max_size_ + pos; + sum_tree_[idx] = priority_alpha; + while (idx > 1) { + idx >>= 1; + sum_tree_[idx] = sum_tree_[2 * idx] + sum_tree_[2 * idx + 1]; + } + } + + // traverse the tree to find the leaf whose 
prefix-sum contains value — O(log N) + size_t treeSample_(float value) const { + // clamp to avoid floating-point overshoot past the last leaf + value = std::min(value, treeTotal_() * (1.f - 1e-6f)); + size_t idx = 1; + while (idx < max_size_) { + size_t left = 2 * idx; + if (value <= sum_tree_[left]) { + idx = left; + } else { + value -= sum_tree_[left]; + idx = left + 1; + } + } + return idx - max_size_; + } + + float treeTotal_() const { return sum_tree_[1]; } + + // circular buffer (vector for O(1) indexed access) + std::vector buffer_; + // 1-indexed sum-tree of size 2*max_size_; leaf for position p at index max_size_+p + std::vector sum_tree_; + // raw priorities (before alpha exponent), same indexing as buffer_ + std::vector priorities_; + + size_t write_pos_; + size_t current_size_; + + float alpha_; // priority exponent + float beta_; // IS weight exponent (annealed from beta0 toward beta_max_) + float beta_max_; + float beta_increment_; // added to beta_ each call to sample() + float epsilon_; // priority floor (prevents zero probability) + float max_priority_; // maximum raw priority seen; assigned to new entries + float min_p_alpha_; // minimum p^alpha currently tracked; used to normalise IS weights + + std::mt19937_64 rng_; + float gamma_; + int nstep_; + RewardReductionMode reward_reduction_mode_; + bool skip_incomplete_steps_; +}; + } // namespace rl } // namespace torchfort diff --git a/src/csrc/include/internal/rl/setup.h b/src/csrc/include/internal/rl/setup.h new file mode 100644 index 00000000..082f6d8b --- /dev/null +++ b/src/csrc/include/internal/rl/setup.h @@ -0,0 +1,39 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include "internal/rl/replay_buffer.h" + +namespace torchfort { + +namespace rl { + +// Construct a replay buffer from a YAML replay_buffer node. +// The node must contain a "type" field ("uniform" or "prioritized") and a "parameters" sub-node. +// gamma, nstep and nstep_reward_reduction are taken from the enclosing algorithm configuration +// and forwarded to the buffer constructor. +std::shared_ptr get_replay_buffer(const YAML::Node& rb_node, float gamma, int nstep, + RewardReductionMode nstep_reward_reduction, int rb_device); + +} // namespace rl + +} // namespace torchfort diff --git a/src/csrc/rl/off_policy/ddpg.cpp b/src/csrc/rl/off_policy/ddpg.cpp index 4f809474..6f3cad5f 100644 --- a/src/csrc/rl/off_policy/ddpg.cpp +++ b/src/csrc/rl/off_policy/ddpg.cpp @@ -20,6 +20,7 @@ #include "internal/exceptions.h" #include "internal/rl/off_policy/ddpg.h" +#include "internal/rl/setup.h" #include "torchfort.h" namespace torchfort { @@ -112,26 +113,8 @@ DDPGSystem::DDPGSystem(const char* name, const YAML::Node& system_node, int mode } if (system_node["replay_buffer"]) { - auto rb_node = system_node["replay_buffer"]; - std::string rb_type = sanitize(rb_node["type"].as()); - if (rb_node["parameters"]) { - auto params = get_params(rb_node["parameters"]); - std::set supported_params{"type", "max_size", "min_size", "n_envs"}; - check_params(supported_params, params.keys()); - auto max_size = static_cast(params.get_param("max_size")[0]); - auto min_size = static_cast(params.get_param("min_size")[0]); - auto n_envs 
= static_cast(params.get_param("n_envs", 1)[0]); - - // distinction between buffer types - if (rb_type == "uniform") { - replay_buffer_ = std::make_shared(max_size, min_size, n_envs, gamma_, nstep_, - nstep_reward_reduction_, rb_device); - } else { - THROW_INVALID_USAGE(rb_type); - } - } else { - THROW_INVALID_USAGE("Missing parameters section in replay_buffer section in configuration file."); - } + replay_buffer_ = rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, + nstep_reward_reduction_, rb_device); } else { THROW_INVALID_USAGE("Missing replay_buffer section in configuration file."); } diff --git a/src/csrc/rl/off_policy/sac.cpp b/src/csrc/rl/off_policy/sac.cpp index 3654357c..2323449f 100644 --- a/src/csrc/rl/off_policy/sac.cpp +++ b/src/csrc/rl/off_policy/sac.cpp @@ -21,6 +21,7 @@ #include "internal/exceptions.h" #include "internal/rl/distributions.h" #include "internal/rl/off_policy/sac.h" +#include "internal/rl/setup.h" namespace torchfort { @@ -115,26 +116,8 @@ SACSystem::SACSystem(const char* name, const YAML::Node& system_node, int model_ } if (system_node["replay_buffer"]) { - auto rb_node = system_node["replay_buffer"]; - std::string rb_type = sanitize(rb_node["type"].as()); - if (rb_node["parameters"]) { - auto params = get_params(rb_node["parameters"]); - std::set supported_params{"type", "max_size", "min_size", "n_envs"}; - check_params(supported_params, params.keys()); - auto max_size = static_cast(params.get_param("max_size")[0]); - auto min_size = static_cast(params.get_param("min_size")[0]); - auto n_envs = static_cast(params.get_param("n_envs", 1)[0]); - - // distinction between buffer types - if (rb_type == "uniform") { - replay_buffer_ = std::make_shared(max_size, min_size, n_envs, gamma_, nstep_, - nstep_reward_reduction_, rb_device); - } else { - THROW_INVALID_USAGE(rb_type); - } - } else { - THROW_INVALID_USAGE("Missing parameters section in replay_buffer section in configuration file."); - } + replay_buffer_ = 
rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, + nstep_reward_reduction_, rb_device); } else { THROW_INVALID_USAGE("Missing replay_buffer section in configuration file."); } diff --git a/src/csrc/rl/off_policy/td3.cpp b/src/csrc/rl/off_policy/td3.cpp index 09957054..0c23c2ca 100644 --- a/src/csrc/rl/off_policy/td3.cpp +++ b/src/csrc/rl/off_policy/td3.cpp @@ -20,6 +20,7 @@ #include "internal/exceptions.h" #include "internal/rl/off_policy/td3.h" +#include "internal/rl/setup.h" namespace torchfort { @@ -136,26 +137,8 @@ TD3System::TD3System(const char* name, const YAML::Node& system_node, int model_ } if (system_node["replay_buffer"]) { - auto rb_node = system_node["replay_buffer"]; - std::string rb_type = sanitize(rb_node["type"].as()); - if (rb_node["parameters"]) { - auto params = get_params(rb_node["parameters"]); - std::set supported_params{"type", "max_size", "min_size", "n_envs"}; - check_params(supported_params, params.keys()); - auto max_size = static_cast(params.get_param("max_size")[0]); - auto min_size = static_cast(params.get_param("min_size")[0]); - auto n_envs = static_cast(params.get_param("n_envs", 1)[0]); - - // distinction between buffer types - if (rb_type == "uniform") { - replay_buffer_ = std::make_shared(max_size, min_size, n_envs, gamma_, nstep_, - nstep_reward_reduction_, rb_device); - } else { - THROW_INVALID_USAGE(rb_type); - } - } else { - THROW_INVALID_USAGE("Missing parameters section in replay_buffer section in configuration file."); - } + replay_buffer_ = rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, + nstep_reward_reduction_, rb_device); } else { THROW_INVALID_USAGE("Missing replay_buffer section in configuration file."); } diff --git a/src/csrc/rl/setup.cpp b/src/csrc/rl/setup.cpp new file mode 100644 index 00000000..f720a2f7 --- /dev/null +++ b/src/csrc/rl/setup.cpp @@ -0,0 +1,65 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/rl/setup.h" + +#include "internal/exceptions.h" +#include "internal/setup.h" + +namespace torchfort { + +namespace rl { + +std::shared_ptr get_replay_buffer(const YAML::Node& rb_node, float gamma, int nstep, + RewardReductionMode nstep_reward_reduction, int rb_device) { + if (!rb_node["type"]) { + THROW_INVALID_USAGE("Missing type field in replay_buffer section in configuration file."); + } + std::string rb_type = sanitize(rb_node["type"].as()); + + if (!rb_node["parameters"]) { + THROW_INVALID_USAGE("Missing parameters section in replay_buffer section in configuration file."); + } + auto params = get_params(rb_node["parameters"]); + + std::set supported_params{"type", "max_size", "min_size", "n_envs", + "alpha", "beta0", "beta_max", "beta_steps"}; + check_params(supported_params, params.keys()); + + auto max_size = static_cast(params.get_param("max_size")[0]); + auto min_size = static_cast(params.get_param("min_size")[0]); + auto n_envs = static_cast(params.get_param("n_envs", 1)[0]); + + if (rb_type == "uniform") { + return std::make_shared(max_size, min_size, n_envs, gamma, nstep, + nstep_reward_reduction, rb_device); + } else if (rb_type == "prioritized") { + float alpha = params.get_param("alpha", 0.6f)[0]; + float beta0 = params.get_param("beta0", 0.4f)[0]; + float beta_max = params.get_param("beta_max", 1.0f)[0]; + size_t beta_steps = 
static_cast(params.get_param("beta_steps", 100000)[0]); + return std::make_shared(max_size, min_size, n_envs, gamma, nstep, + nstep_reward_reduction, alpha, beta0, beta_max, + beta_steps, rb_device); + } else { + THROW_INVALID_USAGE("Unknown replay_buffer type: " + rb_type); + } +} + +} // namespace rl + +} // namespace torchfort diff --git a/tests/rl/test_replay_buffer.cpp b/tests/rl/test_replay_buffer.cpp index a2e94d9f..b99303a4 100644 --- a/tests/rl/test_replay_buffer.cpp +++ b/tests/rl/test_replay_buffer.cpp @@ -444,6 +444,244 @@ TEST(RewardNormalization, LargeScaleNormalizedToUnitStd) { << "Mean should be preserved as ~true_mean/true_std, not removed"; } +// ============================================================================= +// PrioritizedReplayBuffer tests +// ============================================================================= + +class PrioritizedBuffer : public testing::TestWithParam {}; + +std::shared_ptr getTestPERBuffer(int buffer_size, int n_envs = 1, float gamma = 0.95f, + int nstep = 1, float alpha = 0.6f, + float beta0 = 0.4f) { + torch::NoGradGuard no_grad; + + // beta_steps=0: no annealing, beta stays at beta0 for deterministic tests + auto rbuff = std::make_shared(buffer_size, buffer_size, n_envs, gamma, nstep, + rl::RewardReductionMode::Sum, alpha, beta0, 1.0f, 0, -1); + + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_int_distribution dist(1, 5); + + torch::Tensor state = torch::zeros({n_envs, 1}, torch::kFloat32); + for (int i = 0; i < buffer_size; ++i) { + auto action = torch::ones({n_envs, 1}, torch::kFloat32) * static_cast(dist(rng)); + auto state_p = state + action; + auto rtens = action.index({"...", 0}).clone(); + auto dtens = torch::zeros({n_envs}, torch::kFloat32); + rbuff->update(state, action, state_p, rtens, dtens); + state.copy_(state_p); + } + return rbuff; +} + +// Shape and dtype checks +TEST_P(PrioritizedBuffer, ShapeConsistency) { + torch::manual_seed(666); + unsigned int n_envs = 
GetParam(); + unsigned int batch_size = 32; + unsigned int buffer_size = 4 * batch_size; + + auto rbuff = getTestPERBuffer(buffer_size, n_envs); + + torch::Tensor s, a, sp, r, d, is_weights, indices; + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); + + EXPECT_EQ(s.dim(), 2); + EXPECT_EQ(a.dim(), 2); + EXPECT_EQ(sp.dim(), 2); + EXPECT_EQ(r.dim(), 1); + EXPECT_EQ(d.dim(), 1); + EXPECT_EQ(is_weights.dim(), 1); + EXPECT_EQ(indices.dim(), 1); + + EXPECT_EQ(s.size(0), batch_size); + EXPECT_EQ(a.size(0), batch_size); + EXPECT_EQ(sp.size(0), batch_size); + EXPECT_EQ(r.size(0), batch_size); + EXPECT_EQ(d.size(0), batch_size); + EXPECT_EQ(is_weights.size(0), batch_size); + EXPECT_EQ(indices.size(0), batch_size); + + EXPECT_EQ(is_weights.scalar_type(), torch::kFloat32); + EXPECT_EQ(indices.scalar_type(), torch::kLong); +} + +// Before any update_priorities, all entries have equal (max) priority so all +// weights must be exactly 1.0 — uniform sampling as a special case of PER. +TEST_P(PrioritizedBuffer, InitialWeightsUniform) { + torch::manual_seed(666); + unsigned int n_envs = GetParam(); + unsigned int batch_size = 32; + unsigned int buffer_size = 4 * batch_size; + + auto rbuff = getTestPERBuffer(buffer_size, n_envs); + + torch::Tensor s, a, sp, r, d, is_weights, indices; + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); + + EXPECT_FLOAT_EQ(is_weights.min().item(), 1.f); + EXPECT_FLOAT_EQ(is_weights.max().item(), 1.f); +} + +// After update_priorities with varied TD errors: weights must be in (0, 1] and +// the entry with the lowest priority must achieve the maximum weight of 1.0. 
+TEST_P(PrioritizedBuffer, WeightRange) { + torch::manual_seed(666); + unsigned int n_envs = GetParam(); + unsigned int batch_size = 32; + unsigned int buffer_size = 4 * batch_size; + + auto rbuff = getTestPERBuffer(buffer_size, n_envs); + + torch::Tensor s, a, sp, r, d, is_weights, indices; + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); + + // set half the sampled entries to high TD, half to low TD + auto td_errors = torch::cat({torch::ones(batch_size / 2) * 10.f, + torch::ones(batch_size / 2) * 0.01f}); + rbuff->update_priorities(indices, td_errors); + + // sample again to see updated weights + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); + + EXPECT_GT(is_weights.min().item(), 0.f); + EXPECT_FLOAT_EQ(is_weights.max().item(), 1.f); +} + +// Higher priority (larger TD error) must map to a smaller IS weight since +// high-priority entries are overrepresented in sampling. +TEST_P(PrioritizedBuffer, WeightOrdering) { + torch::manual_seed(42); + + // small buffer so we control exactly which positions exist + unsigned int n_envs = GetParam(); + unsigned int buffer_size = 4; + unsigned int batch_size = 200; + float alpha = 1.0f, beta0 = 1.0f; // full prioritization + full IS for clearest signal + + auto rbuff = getTestPERBuffer(buffer_size, n_envs, 0.95f, 1, alpha, beta0); + + // directly set known priorities: position 0 → high TD, position 1 → low TD + auto known_indices = torch::tensor({0L, 1L}); + auto known_td_errors = torch::tensor({100.f, 0.001f}); + rbuff->update_priorities(known_indices, known_td_errors); + + // sample a large batch and collect weights for positions 0 and 1 + torch::Tensor s, a, sp, r, d, is_weights, indices; + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); + + float sum_w0 = 0.f, sum_w1 = 0.f; + int cnt0 = 0, cnt1 = 0; + auto idx_acc = indices.accessor(); + auto w_acc = is_weights.accessor(); + for (int i = 0; i < batch_size; ++i) { + if (idx_acc[i] == 0) { 
sum_w0 += w_acc[i]; cnt0++; } + if (idx_acc[i] == 1) { sum_w1 += w_acc[i]; cnt1++; } + } + + // both positions must appear in a batch of 200 from a buffer of 4 + ASSERT_GT(cnt0, 0) << "position 0 (high priority) must appear in batch"; + ASSERT_GT(cnt1, 0) << "position 1 (low priority) must appear in batch"; + + float avg_w0 = sum_w0 / cnt0; + float avg_w1 = sum_w1 / cnt1; + + // high priority (pos 0) must have strictly lower IS weight than low priority (pos 1) + EXPECT_LT(avg_w0, avg_w1) << "high-priority entry must have lower IS weight"; + // the low-priority entry must achieve the maximum weight of 1.0 + EXPECT_FLOAT_EQ(avg_w1, 1.f); +} + +// Entries with high TD error should be sampled significantly more often than +// entries with low TD error. +TEST_P(PrioritizedBuffer, PriorityBias) { + torch::manual_seed(42); + + unsigned int n_envs = GetParam(); + unsigned int buffer_size = 50; + unsigned int batch_size = 2000; + float alpha = 1.0f; // full prioritization for a clear statistical signal + + auto rbuff = getTestPERBuffer(buffer_size, n_envs, 0.95f, 1, alpha, 0.0f); + + // positions 0–4: high TD (10.0), positions 5–49: low TD (0.001) + auto hi_idx = torch::arange(0, 5, torch::kLong); + auto lo_idx = torch::arange(5, 50, torch::kLong); + rbuff->update_priorities(hi_idx, torch::ones(5) * 10.f); + rbuff->update_priorities(lo_idx, torch::ones(45) * 0.001f); + + // sample a large batch and count how often high-priority positions appear + torch::Tensor s, a, sp, r, d, is_weights, indices; + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); + + int hi_count = ((indices < 5).sum()).item(); + float hi_fraction = static_cast(hi_count) / static_cast(batch_size); + + // With alpha=1: p_hi/(p_hi*5 + p_lo*45) ≈ 10*5/(10*5+0.001*45) ≈ 99.9% expected. + // Use a very conservative threshold of 80% to be robust to RNG variation. 
+ EXPECT_GT(hi_fraction, 0.8f) << "high-priority entries should dominate sampling"; +} + +// Buffer data integrity: s + a = s', r = |a| (set up by getTestPERBuffer) +TEST_P(PrioritizedBuffer, EntryConsistency) { + torch::manual_seed(666); + unsigned int n_envs = GetParam(); + unsigned int batch_size = 32; + unsigned int buffer_size = 4 * batch_size; + + auto rbuff = getTestPERBuffer(buffer_size, n_envs); + + torch::Tensor s, a, sp, r, d, is_weights, indices; + float state_diff = 0.f, reward_diff = 0.f; + for (int i = 0; i < 4; ++i) { + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); + state_diff += torch::sum(torch::abs(s + a - sp)).item(); + reward_diff += torch::sum(torch::abs(a.index({"...", 0}) - r)).item(); + } + + EXPECT_FLOAT_EQ(state_diff, 0.f); + EXPECT_FLOAT_EQ(reward_diff, 0.f); +} + +// Save/load must preserve buffer contents and priority state so that the same +// IS weights are produced (given the same RNG seed) after a round-trip. +TEST_P(PrioritizedBuffer, SaveRestore) { + torch::manual_seed(666); + unsigned int n_envs = GetParam(); + unsigned int buffer_size = 32; + + auto rbuff = getTestPERBuffer(buffer_size, n_envs); + + // set varied priorities so the save is non-trivial + torch::Tensor s, a, sp, r, d, is_weights, indices; + std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(16); + auto td_errors = torch::rand(16) * 5.f + 0.01f; + rbuff->update_priorities(indices, td_errors); + + rbuff->setSeed(42); + torch::Tensor s_b, a_b, sp_b, r_b, d_b, w_b, idx_b; + std::tie(s_b, a_b, sp_b, r_b, d_b, w_b, idx_b) = rbuff->sample(16); + + // save → reset → load + rbuff->save("/tmp/per_buffer.pt"); + rbuff->reset(); + rbuff->load("/tmp/per_buffer.pt"); + + rbuff->setSeed(42); + torch::Tensor s_a, a_a, sp_a, r_a, d_a, w_a, idx_a; + std::tie(s_a, a_a, sp_a, r_a, d_a, w_a, idx_a) = rbuff->sample(16); + + // indices and weights must be identical after round-trip + EXPECT_FLOAT_EQ(torch::sum(torch::abs(w_b - w_a)).item(), 0.f); 
+ EXPECT_FLOAT_EQ(torch::sum(torch::abs((idx_b - idx_a).to(torch::kFloat))).item(), 0.f); + // buffer data integrity + EXPECT_FLOAT_EQ(torch::sum(torch::abs(s_b - s_a)).item(), 0.f); + EXPECT_FLOAT_EQ(torch::sum(torch::abs(r_b - r_a)).item(), 0.f); +} + +INSTANTIATE_TEST_SUITE_P(MultiEnv, PrioritizedBuffer, testing::Range(1, 3), testing::PrintToStringParamName()); + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); From 3f37ba7897a99a0dedc37100fbe779a17664e083 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 28 Apr 2026 04:23:24 -0700 Subject: [PATCH 3/9] fixing lvalue issue Signed-off-by: Thorsten Kurth --- src/csrc/include/internal/rl/replay_buffer.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/csrc/include/internal/rl/replay_buffer.h b/src/csrc/include/internal/rl/replay_buffer.h index ed82a27d..90f464f7 100644 --- a/src/csrc/include/internal/rl/replay_buffer.h +++ b/src/csrc/include/internal/rl/replay_buffer.h @@ -548,8 +548,10 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f void update_priorities(torch::Tensor indices, torch::Tensor td_errors) override { torch::NoGradGuard no_grad; - auto idx_acc = indices.to(torch::kCPU).accessor(); - auto td_acc = td_errors.to(torch::kCPU).accessor(); + auto indices_cpu = indices.to(torch::kCPU).contiguous(); + auto td_errors_cpu = td_errors.to(torch::kCPU).contiguous(); + auto idx_acc = indices_cpu.accessor(); + auto td_acc = td_errors_cpu.accessor(); for (int64_t i = 0; i < idx_acc.size(0); ++i) { size_t pos = static_cast(idx_acc[i]); From 12eb806b1014f93d0bda54c22df1c56d7b0b8ca4 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Jun 2026 00:47:43 -0700 Subject: [PATCH 4/9] formatting Signed-off-by: Thorsten Kurth --- CMakeLists.txt | 53 ++++---- README.md | 2 +- requirements.txt | 22 ++-- src/csrc/include/internal/rl/replay_buffer.h | 123 +++++++++---------- src/csrc/rl/off_policy/ddpg.cpp | 4 +- 
src/csrc/rl/off_policy/sac.cpp | 4 +- src/csrc/rl/off_policy/td3.cpp | 4 +- src/csrc/rl/setup.cpp | 21 ++-- tests/rl/test_replay_buffer.cpp | 60 ++++----- 9 files changed, 145 insertions(+), 148 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ce29b5e..a2e93fed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,10 +4,10 @@ if (NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE RelWithDebInfo) endif() -# https://github.com/NVIDIA/TorchFort/issues/3 +#https: // github.com/NVIDIA/TorchFort/issues/3 cmake_policy(SET CMP0057 NEW) -# User-defined build options +#User - defined build options set(TORCHFORT_CUDA_CC_LIST "70;80;90" CACHE STRING "List of CUDA compute capabilities to build torchfort for.") set(TORCHFORT_NCCL_ROOT CACHE STRING "Path to search for NCCL installation. Default NVIDA HPC SDK provided NCCL version if available.") set(TORCHFORT_YAML_CPP_ROOT CACHE STRING "Path to search for yaml-cpp installation.") @@ -16,7 +16,7 @@ option(TORCHFORT_BUILD_EXAMPLES "Build examples" OFF) option(TORCHFORT_BUILD_TESTS "Build tests" OFF) option(TORCHFORT_ENABLE_GPU "Enable GPU/CUDA support" ON) -# For backward-compatibility with existing variable +#For backward - compatibility with existing variable if (YAML_CPP_ROOT) set(TORCHFORT_YAML_CPP_ROOT ${YAML_CPP_ROOT}) endif() @@ -34,13 +34,12 @@ endif() project(torchfort LANGUAGES ${LANGS}) if (CMAKE_CXX_COMPILER_ID STREQUAL "NVHPC") - # __rdtsc() in torch not supported by nvc++. Use g++ for CXX files. +#__rdtsc() in torch not supported by nvc++.Use g++ for CXX files. message(FATAL_ERROR "TorchFort does not support compilation of C++ files with nvc++. 
" "Set CMAKE_CXX_COMPILER to g++ to proceed.") endif() - -# unit testing with gtest +#unit testing with gtest if (TORCHFORT_BUILD_TESTS) enable_testing() include(CTest) @@ -49,29 +48,29 @@ if (TORCHFORT_BUILD_TESTS) googletest URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip ) - # For Windows: Prevent overriding the parent project's compiler/linker settings +#For Windows : Prevent overriding the parent project's compiler/linker settings set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) include(GoogleTest) endif() -# MPI +#MPI find_package(MPI REQUIRED) -# CUDA +#CUDA if (TORCHFORT_ENABLE_GPU) find_package(CUDAToolkit REQUIRED) - # HPC SDK - # Locate and append NVHPC CMake configuration if available +#HPC SDK +#Locate and append NVHPC CMake configuration if available find_program(NVHPC_CXX_BIN "nvc++") if (NVHPC_CXX_BIN) string(REPLACE "compilers/bin/nvc++" "cmake" NVHPC_CMAKE_DIR ${NVHPC_CXX_BIN}) set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};${NVHPC_CMAKE_DIR}") find_package(NVHPC COMPONENTS "") endif() - - # Get NCCL library (with optional override) + +#Get NCCL library(with optional override) if (TORCHFORT_NCCL_ROOT) find_path(NCCL_INCLUDE_DIR REQUIRED NAMES nccl.h @@ -103,8 +102,8 @@ if (TORCHFORT_ENABLE_GPU) message(STATUS "Using NCCL library: ${NCCL_LIBRARY}") - # PyTorch - # Set TORCH_CUDA_ARCH_LIST string to match TORCHFORT_CUDA_CC_LIST +#PyTorch +#Set TORCH_CUDA_ARCH_LIST string to match TORCHFORT_CUDA_CC_LIST foreach(CUDA_CC ${TORCHFORT_CUDA_CC_LIST}) string(REGEX REPLACE "([0-9])$" ".\\1" CUDA_CC_W_DOT ${CUDA_CC}) list(APPEND TORCH_CUDA_ARCH_LIST ${CUDA_CC_W_DOT}) @@ -114,15 +113,15 @@ endif() find_package(Torch REQUIRED) -# Generate configuration header +#Generate configuration header configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/include/torchfort_config.h.in ${CMAKE_BINARY_DIR}/include/torchfort_config.h @ONLY ) -# yaml-cpp -#find_package(yaml-cpp REQUIRED) +#yaml 
- cpp +#find_package(yaml - cpp REQUIRED) find_path(YAML_CPP_INCLUDE_DIR REQUIRED NAMES yaml-cpp/yaml.h HINTS ${TORCHFORT_YAML_CPP_ROOT}/include @@ -133,7 +132,7 @@ find_library(YAML_CPP_LIBRARY REQUIRED ) message(STATUS "Using yaml-cpp library: ${YAML_CPP_LIBRARY}") -# C/C++ shared library +#C / C++ shared library add_library(${PROJECT_NAME} SHARED) set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) @@ -220,17 +219,17 @@ install( INCLUDES DESTINATION ${CMAKE_INSTALL_PREFIX}/include ) -# Install generated configuration header +#Install generated configuration header install( FILES ${CMAKE_BINARY_DIR}/include/torchfort_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include ) -# Fortran library and module +#Fortran library and module if (TORCHFORT_BUILD_FORTRAN) if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC") - # Creating -gpu argument string for nvfortran GPU compilation +#Creating - gpu argument string for nvfortran GPU compilation foreach(CUDA_CC ${TORCHFORT_CUDA_CC_LIST}) list(APPEND CUF_GPU_ARG "cc${CUDA_CC}") endforeach() @@ -257,17 +256,17 @@ if (TORCHFORT_BUILD_FORTRAN) install( TARGETS "${PROJECT_NAME}_fort" ) - # install Fortran module +#install Fortran module install(FILES ${CMAKE_BINARY_DIR}/include/torchfort.mod DESTINATION ${CMAKE_INSTALL_PREFIX}/include) endif() -# install Python files +#install Python files install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/python/wandb_helper.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/python) -# install docs +#install docs install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/docs DESTINATION ${CMAKE_INSTALL_PREFIX}) -# build examples +#build examples if (TORCHFORT_BUILD_EXAMPLES) add_subdirectory(examples/cpp/cart_pole) if (TORCHFORT_BUILD_FORTRAN) @@ -276,7 +275,7 @@ if (TORCHFORT_BUILD_EXAMPLES) endif() endif() -# build tests +#build tests if (TORCHFORT_BUILD_TESTS) add_subdirectory(tests/general) add_subdirectory(tests/supervised) diff --git a/README.md b/README.md index 
a0a1c405..3f741e01 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# TorchFort +#TorchFort An Online Deep Learning Interface for HPC programs on NVIDIA GPUs diff --git a/requirements.txt b/requirements.txt index 68acbbdf..1b08c977 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ -# basic packages -ruamel-yaml +#basic packages +ruamel - yaml -# pytorch and some dependencies -torch==2.8.0 +#pytorch and some dependencies + torch == + 2.8.0 -# training monitoring -wandb +#training monitoring + wandb -# RL example visualization related -pygame -moviepy +#RL example visualization related + pygame moviepy -# Supervised learning example visualization related -matplotlib +#Supervised learning example visualization related + matplotlib diff --git a/src/csrc/include/internal/rl/replay_buffer.h b/src/csrc/include/internal/rl/replay_buffer.h index 90f464f7..ec63dc0e 100644 --- a/src/csrc/include/internal/rl/replay_buffer.h +++ b/src/csrc/include/internal/rl/replay_buffer.h @@ -30,8 +30,8 @@ namespace torchfort { namespace rl { using BufferEntry = std::tuple; -using SampleResult = std::tuple; +using SampleResult = + std::tuple; enum RewardReductionMode { Sum = 1, Mean = 2, WeightedMean = 3, SumNoSkip = 4, MeanNoSkip = 5, WeightedMeanNoSkip = 6 }; @@ -380,9 +380,8 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f size_t beta_steps, int device) : ReplayBuffer(max_size, min_size, n_envs, device), gamma_(gamma), nstep_(nstep), alpha_(alpha), beta_(beta0), beta_max_(beta_max), - beta_increment_(beta_steps > 0 ? (beta_max - beta0) / static_cast(beta_steps) : 0.f), - epsilon_(1e-6f), max_priority_(1.f), min_p_alpha_(std::numeric_limits::max()), - write_pos_(0), current_size_(0) { + beta_increment_(beta_steps > 0 ? 
(beta_max - beta0) / static_cast(beta_steps) : 0.f), epsilon_(1e-6f), + max_priority_(1.f), min_p_alpha_(std::numeric_limits::max()), write_pos_(0), current_size_(0) { skip_incomplete_steps_ = true; if (reward_reduction_mode == RewardReductionMode::MeanNoSkip) { @@ -410,19 +409,17 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f torch::NoGradGuard no_grad; if ((s.sizes()[0] != n_envs_) || (a.sizes()[0] != n_envs_) || (sp.sizes()[0] != n_envs_)) { - throw std::runtime_error( - "PrioritizedReplayBuffer::update: leading dimension of s, a, sp must equal n_envs"); + throw std::runtime_error("PrioritizedReplayBuffer::update: leading dimension of s, a, sp must equal n_envs"); } if ((r.sizes()[0] != n_envs_) || (d.sizes()[0] != n_envs_)) { - throw std::runtime_error( - "PrioritizedReplayBuffer::update: leading dimension of r, d must equal n_envs"); + throw std::runtime_error("PrioritizedReplayBuffer::update: leading dimension of r, d must equal n_envs"); } - auto sc = s.to(device_, s.dtype(), false, true); - auto ac = a.to(device_, a.dtype(), false, true); + auto sc = s.to(device_, s.dtype(), false, true); + auto ac = a.to(device_, a.dtype(), false, true); auto spc = sp.to(device_, sp.dtype(), false, true); - auto rc = r.to(device_, r.dtype(), false, true); - auto dc = d.to(device_, d.dtype(), false, true); + auto rc = r.to(device_, r.dtype(), false, true); + auto dc = d.to(device_, d.dtype(), false, true); // write entry; zero priority until it becomes a valid n-step starting point size_t pos = write_pos_; @@ -446,19 +443,19 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f SampleResult sample(int batch_size) override { torch::NoGradGuard no_grad; - auto stens_list = std::vector(batch_size); - auto atens_list = std::vector(batch_size); + auto stens_list = std::vector(batch_size); + auto atens_list = std::vector(batch_size); auto sptens_list = std::vector(batch_size); - auto rtens_list = 
std::vector(batch_size); - auto dtens_list = std::vector(batch_size); - auto wtens_list = std::vector(batch_size); - auto itens_list = std::vector(batch_size); + auto rtens_list = std::vector(batch_size); + auto dtens_list = std::vector(batch_size); + auto wtens_list = std::vector(batch_size); + auto itens_list = std::vector(batch_size); // anneal beta towards beta_max_ beta_ = std::min(beta_max_, beta_ + beta_increment_); - float total = treeTotal_(); - float segment = total / static_cast(batch_size); + float total = treeTotal_(); + float segment = total / static_cast(batch_size); // minimum p^alpha currently in the buffer — normalises IS weights so max weight = 1 float min_p_alpha = min_p_alpha_; @@ -486,24 +483,24 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f // extract (state, action, ...) at (pos, env_idx) torch::Tensor stens, atens, sptens, rtens, dtens; std::tie(stens, atens, sptens, rtens, dtens) = buffer_[pos]; - stens_list[s] = stens.index({env_idx, "..."}).clone(); - atens_list[s] = atens.index({env_idx, "..."}).clone(); + stens_list[s] = stens.index({env_idx, "..."}).clone(); + atens_list[s] = atens.index({env_idx, "..."}).clone(); sptens_list[s] = sptens.index({env_idx, "..."}).clone(); - rtens_list[s] = rtens.index({env_idx}).clone(); - dtens_list[s] = dtens.index({env_idx}).clone(); + rtens_list[s] = rtens.index({env_idx}).clone(); + dtens_list[s] = dtens.index({env_idx}).clone(); // n-step rollout — identical logic to UniformReplayBuffer, but uses modular indexing float r_norm = 1.f; - int r_count = 1; - bool skip = false; + int r_count = 1; + bool skip = false; torch::Tensor deff = 1.f - dtens_list[s]; for (int off = 1; off < nstep_; ++off) { size_t next_pos = (pos + static_cast(off)) % max_size_; std::tie(stens, atens, sptens, rtens, dtens) = buffer_[next_pos]; sptens_list[s] = sptens.index({env_idx, "..."}).clone(); float gamma_eff = static_cast(std::pow(gamma_, off)); - rtens_list[s] = rtens_list[s] + 
gamma_eff * rtens.index({env_idx}); - r_norm += gamma_eff; + rtens_list[s] = rtens_list[s] + gamma_eff * rtens.index({env_idx}); + r_norm += gamma_eff; r_count++; float d_val = dtens.index({env_idx}).item(); @@ -532,7 +529,7 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f } auto float_opts = torch::TensorOptions().dtype(torch::kFloat32).device(device_); - auto long_opts = torch::TensorOptions().dtype(torch::kLong).device(device_); + auto long_opts = torch::TensorOptions().dtype(torch::kLong).device(device_); wtens_list[s] = torch::tensor(weight, float_opts); itens_list[s] = torch::tensor(static_cast(pos), long_opts); @@ -551,23 +548,23 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f auto indices_cpu = indices.to(torch::kCPU).contiguous(); auto td_errors_cpu = td_errors.to(torch::kCPU).contiguous(); auto idx_acc = indices_cpu.accessor(); - auto td_acc = td_errors_cpu.accessor(); + auto td_acc = td_errors_cpu.accessor(); for (int64_t i = 0; i < idx_acc.size(0); ++i) { - size_t pos = static_cast(idx_acc[i]); - float raw_p = std::abs(td_acc[i]) + epsilon_; + size_t pos = static_cast(idx_acc[i]); + float raw_p = std::abs(td_acc[i]) + epsilon_; float p_alpha = std::pow(raw_p, alpha_); treeUpdate_(pos, p_alpha); priorities_[pos] = raw_p; max_priority_ = std::max(max_priority_, raw_p); - min_p_alpha_ = std::min(min_p_alpha_, p_alpha); + min_p_alpha_ = std::min(min_p_alpha_, p_alpha); } } BufferEntry get(int index) override { if (index < 0 || static_cast(index) >= current_size_) { - throw std::runtime_error("PrioritizedReplayBuffer::get: index " + std::to_string(index) + - " out of bounds [0, " + std::to_string(current_size_) + ")."); + throw std::runtime_error("PrioritizedReplayBuffer::get: index " + std::to_string(index) + " out of bounds [0, " + + std::to_string(current_size_) + ")."); } return buffer_[index]; } @@ -576,10 +573,10 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public 
std::enable_shared_f void reset() override { current_size_ = 0; - write_pos_ = 0; + write_pos_ = 0; max_priority_ = 1.f; - min_p_alpha_ = std::numeric_limits::max(); - std::fill(sum_tree_.begin(), sum_tree_.end(), 0.f); + min_p_alpha_ = std::numeric_limits::max(); + std::fill(sum_tree_.begin(), sum_tree_.end(), 0.f); std::fill(priorities_.begin(), priorities_.end(), 0.f); } @@ -616,26 +613,24 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f if (!std::filesystem::exists(root_dir)) { bool rv = std::filesystem::create_directory(root_dir); if (!rv) { - throw std::runtime_error("PrioritizedReplayBuffer::save: unable to create directory " + - root_dir.native() + "."); + throw std::runtime_error("PrioritizedReplayBuffer::save: unable to create directory " + root_dir.native() + + "."); } } - torch::save(s_data, root_dir / "s_data.pt"); - torch::save(a_data, root_dir / "a_data.pt"); + torch::save(s_data, root_dir / "s_data.pt"); + torch::save(a_data, root_dir / "a_data.pt"); torch::save(sp_data, root_dir / "sp_data.pt"); - torch::save(r_data, root_dir / "r_data.pt"); - torch::save(d_data, root_dir / "d_data.pt"); + torch::save(r_data, root_dir / "r_data.pt"); + torch::save(d_data, root_dir / "d_data.pt"); // save raw priorities (before alpha) for positions 0..current_size_-1 - auto prio_tensor = torch::tensor( - std::vector(priorities_.begin(), priorities_.begin() + current_size_)); + auto prio_tensor = torch::tensor(std::vector(priorities_.begin(), priorities_.begin() + current_size_)); torch::save({prio_tensor}, root_dir / "priorities.pt"); // save scalar state: [write_pos, current_size, beta, max_priority, min_p_alpha] - auto state_tensor = torch::tensor(std::vector{static_cast(write_pos_), - static_cast(current_size_), beta_, - max_priority_, min_p_alpha_}); + auto state_tensor = torch::tensor(std::vector{ + static_cast(write_pos_), static_cast(current_size_), beta_, max_priority_, min_p_alpha_}); torch::save({state_tensor}, root_dir / 
"per_state.pt"); } @@ -643,20 +638,20 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f std::filesystem::path root_dir(fname); std::vector s_data, a_data, sp_data, r_data, d_data; - torch::load(s_data, root_dir / "s_data.pt"); - torch::load(a_data, root_dir / "a_data.pt"); + torch::load(s_data, root_dir / "s_data.pt"); + torch::load(a_data, root_dir / "a_data.pt"); torch::load(sp_data, root_dir / "sp_data.pt"); - torch::load(r_data, root_dir / "r_data.pt"); - torch::load(d_data, root_dir / "d_data.pt"); + torch::load(r_data, root_dir / "r_data.pt"); + torch::load(d_data, root_dir / "d_data.pt"); // restore scalar state std::vector state_vec; torch::load(state_vec, root_dir / "per_state.pt"); - write_pos_ = static_cast(state_vec[0][0].item()); + write_pos_ = static_cast(state_vec[0][0].item()); current_size_ = static_cast(state_vec[0][1].item()); - beta_ = state_vec[0][2].item(); + beta_ = state_vec[0][2].item(); max_priority_ = state_vec[0][3].item(); - min_p_alpha_ = state_vec[0][4].item(); + min_p_alpha_ = state_vec[0][4].item(); // restore buffer entries at their physical positions buffer_.assign(max_size_, BufferEntry{}); @@ -668,7 +663,7 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f std::vector prio_vec; torch::load(prio_vec, root_dir / "priorities.pt"); std::fill(priorities_.begin(), priorities_.end(), 0.f); - std::fill(sum_tree_.begin(), sum_tree_.end(), 0.f); + std::fill(sum_tree_.begin(), sum_tree_.end(), 0.f); for (size_t i = 0; i < current_size_; ++i) { float raw_p = prio_vec[0][static_cast(i)].item(); priorities_[i] = raw_p; @@ -718,13 +713,13 @@ class PrioritizedReplayBuffer : public ReplayBuffer, public std::enable_shared_f size_t write_pos_; size_t current_size_; - float alpha_; // priority exponent - float beta_; // IS weight exponent (annealed from beta0 toward beta_max_) + float alpha_; // priority exponent + float beta_; // IS weight exponent (annealed from beta0 toward 
beta_max_) float beta_max_; - float beta_increment_; // added to beta_ each call to sample() - float epsilon_; // priority floor (prevents zero probability) - float max_priority_; // maximum raw priority seen; assigned to new entries - float min_p_alpha_; // minimum p^alpha currently tracked; used to normalise IS weights + float beta_increment_; // added to beta_ each call to sample() + float epsilon_; // priority floor (prevents zero probability) + float max_priority_; // maximum raw priority seen; assigned to new entries + float min_p_alpha_; // minimum p^alpha currently tracked; used to normalise IS weights std::mt19937_64 rng_; float gamma_; diff --git a/src/csrc/rl/off_policy/ddpg.cpp b/src/csrc/rl/off_policy/ddpg.cpp index 6f3cad5f..d9028dfe 100644 --- a/src/csrc/rl/off_policy/ddpg.cpp +++ b/src/csrc/rl/off_policy/ddpg.cpp @@ -113,8 +113,8 @@ DDPGSystem::DDPGSystem(const char* name, const YAML::Node& system_node, int mode } if (system_node["replay_buffer"]) { - replay_buffer_ = rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, - nstep_reward_reduction_, rb_device); + replay_buffer_ = + rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, nstep_reward_reduction_, rb_device); } else { THROW_INVALID_USAGE("Missing replay_buffer section in configuration file."); } diff --git a/src/csrc/rl/off_policy/sac.cpp b/src/csrc/rl/off_policy/sac.cpp index 2323449f..da00eda9 100644 --- a/src/csrc/rl/off_policy/sac.cpp +++ b/src/csrc/rl/off_policy/sac.cpp @@ -116,8 +116,8 @@ SACSystem::SACSystem(const char* name, const YAML::Node& system_node, int model_ } if (system_node["replay_buffer"]) { - replay_buffer_ = rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, - nstep_reward_reduction_, rb_device); + replay_buffer_ = + rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, nstep_reward_reduction_, rb_device); } else { THROW_INVALID_USAGE("Missing replay_buffer section in configuration file."); } diff --git 
a/src/csrc/rl/off_policy/td3.cpp b/src/csrc/rl/off_policy/td3.cpp index 0c23c2ca..e44f132d 100644 --- a/src/csrc/rl/off_policy/td3.cpp +++ b/src/csrc/rl/off_policy/td3.cpp @@ -137,8 +137,8 @@ TD3System::TD3System(const char* name, const YAML::Node& system_node, int model_ } if (system_node["replay_buffer"]) { - replay_buffer_ = rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, - nstep_reward_reduction_, rb_device); + replay_buffer_ = + rl::get_replay_buffer(system_node["replay_buffer"], gamma_, nstep_, nstep_reward_reduction_, rb_device); } else { THROW_INVALID_USAGE("Missing replay_buffer section in configuration file."); } diff --git a/src/csrc/rl/setup.cpp b/src/csrc/rl/setup.cpp index f720a2f7..0dcd7f8e 100644 --- a/src/csrc/rl/setup.cpp +++ b/src/csrc/rl/setup.cpp @@ -36,25 +36,24 @@ std::shared_ptr get_replay_buffer(const YAML::Node& rb_node, float } auto params = get_params(rb_node["parameters"]); - std::set supported_params{"type", "max_size", "min_size", "n_envs", - "alpha", "beta0", "beta_max", "beta_steps"}; + std::set supported_params{"type", "max_size", "min_size", "n_envs", + "alpha", "beta0", "beta_max", "beta_steps"}; check_params(supported_params, params.keys()); auto max_size = static_cast(params.get_param("max_size")[0]); auto min_size = static_cast(params.get_param("min_size")[0]); - auto n_envs = static_cast(params.get_param("n_envs", 1)[0]); + auto n_envs = static_cast(params.get_param("n_envs", 1)[0]); if (rb_type == "uniform") { - return std::make_shared(max_size, min_size, n_envs, gamma, nstep, - nstep_reward_reduction, rb_device); + return std::make_shared(max_size, min_size, n_envs, gamma, nstep, nstep_reward_reduction, + rb_device); } else if (rb_type == "prioritized") { - float alpha = params.get_param("alpha", 0.6f)[0]; - float beta0 = params.get_param("beta0", 0.4f)[0]; - float beta_max = params.get_param("beta_max", 1.0f)[0]; + float alpha = params.get_param("alpha", 0.6f)[0]; + float beta0 = 
params.get_param("beta0", 0.4f)[0]; + float beta_max = params.get_param("beta_max", 1.0f)[0]; size_t beta_steps = static_cast(params.get_param("beta_steps", 100000)[0]); - return std::make_shared(max_size, min_size, n_envs, gamma, nstep, - nstep_reward_reduction, alpha, beta0, beta_max, - beta_steps, rb_device); + return std::make_shared(max_size, min_size, n_envs, gamma, nstep, nstep_reward_reduction, + alpha, beta0, beta_max, beta_steps, rb_device); } else { THROW_INVALID_USAGE("Unknown replay_buffer type: " + rb_type); } diff --git a/tests/rl/test_replay_buffer.cpp b/tests/rl/test_replay_buffer.cpp index b99303a4..0f52ad1f 100644 --- a/tests/rl/test_replay_buffer.cpp +++ b/tests/rl/test_replay_buffer.cpp @@ -451,13 +451,12 @@ TEST(RewardNormalization, LargeScaleNormalizedToUnitStd) { class PrioritizedBuffer : public testing::TestWithParam {}; std::shared_ptr getTestPERBuffer(int buffer_size, int n_envs = 1, float gamma = 0.95f, - int nstep = 1, float alpha = 0.6f, - float beta0 = 0.4f) { + int nstep = 1, float alpha = 0.6f, float beta0 = 0.4f) { torch::NoGradGuard no_grad; // beta_steps=0: no annealing, beta stays at beta0 for deterministic tests auto rbuff = std::make_shared(buffer_size, buffer_size, n_envs, gamma, nstep, - rl::RewardReductionMode::Sum, alpha, beta0, 1.0f, 0, -1); + rl::RewardReductionMode::Sum, alpha, beta0, 1.0f, 0, -1); std::random_device dev; std::mt19937 rng(dev()); @@ -465,10 +464,10 @@ std::shared_ptr getTestPERBuffer(int buffer_size, i torch::Tensor state = torch::zeros({n_envs, 1}, torch::kFloat32); for (int i = 0; i < buffer_size; ++i) { - auto action = torch::ones({n_envs, 1}, torch::kFloat32) * static_cast(dist(rng)); + auto action = torch::ones({n_envs, 1}, torch::kFloat32) * static_cast(dist(rng)); auto state_p = state + action; - auto rtens = action.index({"...", 0}).clone(); - auto dtens = torch::zeros({n_envs}, torch::kFloat32); + auto rtens = action.index({"...", 0}).clone(); + auto dtens = torch::zeros({n_envs}, 
torch::kFloat32); rbuff->update(state, action, state_p, rtens, dtens); state.copy_(state_p); } @@ -478,7 +477,7 @@ std::shared_ptr getTestPERBuffer(int buffer_size, i // Shape and dtype checks TEST_P(PrioritizedBuffer, ShapeConsistency) { torch::manual_seed(666); - unsigned int n_envs = GetParam(); + unsigned int n_envs = GetParam(); unsigned int batch_size = 32; unsigned int buffer_size = 4 * batch_size; @@ -511,7 +510,7 @@ TEST_P(PrioritizedBuffer, ShapeConsistency) { // weights must be exactly 1.0 — uniform sampling as a special case of PER. TEST_P(PrioritizedBuffer, InitialWeightsUniform) { torch::manual_seed(666); - unsigned int n_envs = GetParam(); + unsigned int n_envs = GetParam(); unsigned int batch_size = 32; unsigned int buffer_size = 4 * batch_size; @@ -528,7 +527,7 @@ TEST_P(PrioritizedBuffer, InitialWeightsUniform) { // the entry with the lowest priority must achieve the maximum weight of 1.0. TEST_P(PrioritizedBuffer, WeightRange) { torch::manual_seed(666); - unsigned int n_envs = GetParam(); + unsigned int n_envs = GetParam(); unsigned int batch_size = 32; unsigned int buffer_size = 4 * batch_size; @@ -538,8 +537,7 @@ TEST_P(PrioritizedBuffer, WeightRange) { std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); // set half the sampled entries to high TD, half to low TD - auto td_errors = torch::cat({torch::ones(batch_size / 2) * 10.f, - torch::ones(batch_size / 2) * 0.01f}); + auto td_errors = torch::cat({torch::ones(batch_size / 2) * 10.f, torch::ones(batch_size / 2) * 0.01f}); rbuff->update_priorities(indices, td_errors); // sample again to see updated weights @@ -555,15 +553,15 @@ TEST_P(PrioritizedBuffer, WeightOrdering) { torch::manual_seed(42); // small buffer so we control exactly which positions exist - unsigned int n_envs = GetParam(); + unsigned int n_envs = GetParam(); unsigned int buffer_size = 4; - unsigned int batch_size = 200; - float alpha = 1.0f, beta0 = 1.0f; // full prioritization + full IS for clearest signal 
+ unsigned int batch_size = 200; + float alpha = 1.0f, beta0 = 1.0f; // full prioritization + full IS for clearest signal auto rbuff = getTestPERBuffer(buffer_size, n_envs, 0.95f, 1, alpha, beta0); // directly set known priorities: position 0 → high TD, position 1 → low TD - auto known_indices = torch::tensor({0L, 1L}); + auto known_indices = torch::tensor({0L, 1L}); auto known_td_errors = torch::tensor({100.f, 0.001f}); rbuff->update_priorities(known_indices, known_td_errors); @@ -574,10 +572,16 @@ TEST_P(PrioritizedBuffer, WeightOrdering) { float sum_w0 = 0.f, sum_w1 = 0.f; int cnt0 = 0, cnt1 = 0; auto idx_acc = indices.accessor(); - auto w_acc = is_weights.accessor(); + auto w_acc = is_weights.accessor(); for (int i = 0; i < batch_size; ++i) { - if (idx_acc[i] == 0) { sum_w0 += w_acc[i]; cnt0++; } - if (idx_acc[i] == 1) { sum_w1 += w_acc[i]; cnt1++; } + if (idx_acc[i] == 0) { + sum_w0 += w_acc[i]; + cnt0++; + } + if (idx_acc[i] == 1) { + sum_w1 += w_acc[i]; + cnt1++; + } } // both positions must appear in a batch of 200 from a buffer of 4 @@ -598,17 +602,17 @@ TEST_P(PrioritizedBuffer, WeightOrdering) { TEST_P(PrioritizedBuffer, PriorityBias) { torch::manual_seed(42); - unsigned int n_envs = GetParam(); + unsigned int n_envs = GetParam(); unsigned int buffer_size = 50; - unsigned int batch_size = 2000; - float alpha = 1.0f; // full prioritization for a clear statistical signal + unsigned int batch_size = 2000; + float alpha = 1.0f; // full prioritization for a clear statistical signal auto rbuff = getTestPERBuffer(buffer_size, n_envs, 0.95f, 1, alpha, 0.0f); // positions 0–4: high TD (10.0), positions 5–49: low TD (0.001) - auto hi_idx = torch::arange(0, 5, torch::kLong); + auto hi_idx = torch::arange(0, 5, torch::kLong); auto lo_idx = torch::arange(5, 50, torch::kLong); - rbuff->update_priorities(hi_idx, torch::ones(5) * 10.f); + rbuff->update_priorities(hi_idx, torch::ones(5) * 10.f); rbuff->update_priorities(lo_idx, torch::ones(45) * 0.001f); // sample a 
large batch and count how often high-priority positions appear @@ -626,7 +630,7 @@ TEST_P(PrioritizedBuffer, PriorityBias) { // Buffer data integrity: s + a = s', r = |a| (set up by getTestPERBuffer) TEST_P(PrioritizedBuffer, EntryConsistency) { torch::manual_seed(666); - unsigned int n_envs = GetParam(); + unsigned int n_envs = GetParam(); unsigned int batch_size = 32; unsigned int buffer_size = 4 * batch_size; @@ -636,11 +640,11 @@ TEST_P(PrioritizedBuffer, EntryConsistency) { float state_diff = 0.f, reward_diff = 0.f; for (int i = 0; i < 4; ++i) { std::tie(s, a, sp, r, d, is_weights, indices) = rbuff->sample(batch_size); - state_diff += torch::sum(torch::abs(s + a - sp)).item(); + state_diff += torch::sum(torch::abs(s + a - sp)).item(); reward_diff += torch::sum(torch::abs(a.index({"...", 0}) - r)).item(); } - EXPECT_FLOAT_EQ(state_diff, 0.f); + EXPECT_FLOAT_EQ(state_diff, 0.f); EXPECT_FLOAT_EQ(reward_diff, 0.f); } @@ -648,7 +652,7 @@ TEST_P(PrioritizedBuffer, EntryConsistency) { // IS weights are produced (given the same RNG seed) after a round-trip. 
TEST_P(PrioritizedBuffer, SaveRestore) { torch::manual_seed(666); - unsigned int n_envs = GetParam(); + unsigned int n_envs = GetParam(); unsigned int buffer_size = 32; auto rbuff = getTestPERBuffer(buffer_size, n_envs); @@ -673,7 +677,7 @@ TEST_P(PrioritizedBuffer, SaveRestore) { std::tie(s_a, a_a, sp_a, r_a, d_a, w_a, idx_a) = rbuff->sample(16); // indices and weights must be identical after round-trip - EXPECT_FLOAT_EQ(torch::sum(torch::abs(w_b - w_a)).item(), 0.f); + EXPECT_FLOAT_EQ(torch::sum(torch::abs(w_b - w_a)).item(), 0.f); EXPECT_FLOAT_EQ(torch::sum(torch::abs((idx_b - idx_a).to(torch::kFloat))).item(), 0.f); // buffer data integrity EXPECT_FLOAT_EQ(torch::sum(torch::abs(s_b - s_a)).item(), 0.f); From 33fd89765e8adb14f33a00a96e88dd8aa85bf7ae Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Jun 2026 01:02:46 -0700 Subject: [PATCH 5/9] adding description of PER to docs Signed-off-by: Thorsten Kurth --- docs/api/config.rst | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/api/config.rst b/docs/api/config.rst index a2992fc1..7e543760 100644 --- a/docs/api/config.rst +++ b/docs/api/config.rst @@ -403,7 +403,7 @@ The block in the configuration file defining algorithm properties takes the foll parameters: