Merged
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -165,6 +165,7 @@ target_sources(${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/rl/sac_model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/policy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/running_normalizer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/setup.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/off_policy/interface.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/off_policy/ddpg.cpp
23 changes: 22 additions & 1 deletion docs/api/config.rst
@@ -403,7 +403,7 @@ The block in the configuration file defining algorithm properties takes the foll
parameters:
<option> = <value>

-Currently, only type ``uniform`` is supported. The following table lists the available options:
+Currently, types ``uniform`` and ``prioritized`` are supported. The following table lists the available options:

+---------------------------+-----------------+-----------------+------------------------------------------------------------------+
| Replay Buffer Type | Option | Data Type | Description |
@@ -414,11 +414,32 @@ Currently, only type ``uniform`` is supported. The following table lists the ava
+ +-----------------+-----------------+------------------------------------------------------------------+
| | ``n_envs`` | integer | Number of environments |
+---------------------------+-----------------+-----------------+------------------------------------------------------------------+
| ``prioritized`` | ``min_size`` | integer | Minimum number of samples before buffer is ready for training |
+ +-----------------+-----------------+------------------------------------------------------------------+
| | ``max_size`` | integer | Maximum capacity |
+ +-----------------+-----------------+------------------------------------------------------------------+
| | ``n_envs`` | integer | Number of environments |
+ +-----------------+-----------------+------------------------------------------------------------------+
| | ``alpha`` | float | Prioritization exponent; 0=uniform, 1=full (default 0.6) |
+ +-----------------+-----------------+------------------------------------------------------------------+
| | ``beta0`` | float | Initial importance-sampling weight exponent (default 0.4) |
+ +-----------------+-----------------+------------------------------------------------------------------+
| | ``beta_max`` | float | Final importance-sampling weight exponent (default 1.0) |
+ +-----------------+-----------------+------------------------------------------------------------------+
| | ``beta_steps`` | integer | Steps to anneal beta from beta0 to beta_max (default 100000) |
+---------------------------+-----------------+-----------------+------------------------------------------------------------------+

Note that the effective minimum and maximum sizes for each environment are :math:`\mathrm{min\_size} / \mathrm{n\_envs}` and :math:`\mathrm{max\_size} / \mathrm{n\_envs}`, respectively.
You need to ensure that you can store at least one sample for each environment. However, for better algorithm performance, it is highly advised to provide buffers
which can store longer trajectories.
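
As a rough illustration of the per-environment arithmetic above, the following sketch computes the effective sizes and rejects configurations that cannot hold at least one sample per environment. The helper ``per_env_sizes``, the struct ``PerEnvSizes``, and the use of truncating integer division are assumptions made for this example; they are not taken from the TorchFort sources.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical helper type: effective buffer sizes seen by a single environment.
struct PerEnvSizes {
  std::size_t min_size;
  std::size_t max_size;
};

// Split the configured min_size / max_size evenly across n_envs environments.
// Assumption: truncating integer division; the documentation does not state
// how non-divisible sizes are rounded.
PerEnvSizes per_env_sizes(std::size_t min_size, std::size_t max_size, std::size_t n_envs) {
  if (n_envs == 0) {
    throw std::invalid_argument("n_envs must be positive");
  }
  PerEnvSizes sizes{min_size / n_envs, max_size / n_envs};
  // Each environment must be able to store at least one sample.
  if (sizes.max_size < 1) {
    throw std::invalid_argument("max_size must be at least n_envs");
  }
  return sizes;
}
```

Under these assumptions, ``per_env_sizes(1000, 100000, 8)`` yields 125 and 12500 samples per environment, while ``max_size = 4`` with ``n_envs = 8`` is rejected.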

The ``prioritized`` buffer implements Prioritized Experience Replay (`Schaul et al., 2016 <https://arxiv.org/abs/1511.05952>`_), sampling
transitions in proportion to their last observed temporal-difference (TD) error rather than uniformly. The degree of prioritization is controlled
by ``alpha`` (with ``alpha = 0`` recovering uniform sampling), and the resulting sampling bias is corrected by importance-sampling weights whose
exponent ``beta`` is annealed linearly from ``beta0`` to ``beta_max`` over ``beta_steps`` sampling steps. All off-policy algorithms (``DDPG``,
``TD3``, ``SAC``) transparently apply these importance-sampling weights to their losses and feed the per-sample TD errors back to update the
priorities; no changes to the algorithm configuration are required to switch between ``uniform`` and ``prioritized`` buffers.
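
To make the roles of ``alpha`` and ``beta`` concrete, here is a minimal standard-library sketch of the quantities described above: the linear ``beta`` schedule, the sampling probabilities :math:`P(i) = p_i^\alpha / \sum_k p_k^\alpha`, the importance-sampling weights :math:`w_i = (N \cdot P(i))^{-\beta}`, and the weighted loss ``mean(w * (q - y)^2)`` used by the critics in this change. All function names are hypothetical. Normalizing the weights by their maximum follows Schaul et al. and is an assumption here; whether the actual buffer does the same is not shown in this diff. Priorities are assumed to be strictly positive.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Linearly anneal beta from beta0 to beta_max over beta_steps sampling steps,
// then hold it at beta_max.
double anneal_beta(double beta0, double beta_max, long step, long beta_steps) {
  if (beta_steps <= 0) return beta_max;
  const double raw = static_cast<double>(step) / static_cast<double>(beta_steps);
  const double frac = std::min(1.0, std::max(0.0, raw));
  return beta0 + frac * (beta_max - beta0);
}

// P(i) = p_i^alpha / sum_k p_k^alpha; alpha = 0 gives uniform sampling.
std::vector<double> sampling_probabilities(const std::vector<double>& priorities, double alpha) {
  assert(!priorities.empty());
  std::vector<double> probs(priorities.size());
  double total = 0.0;
  for (std::size_t i = 0; i < priorities.size(); ++i) {
    probs[i] = std::pow(priorities[i], alpha);
    total += probs[i];
  }
  for (double& p : probs) p /= total;
  return probs;
}

// w_i = (N * P(i))^(-beta), divided by the largest weight so that w_i <= 1
// (normalization is an assumption of this sketch).
std::vector<double> importance_weights(const std::vector<double>& probs, double beta) {
  assert(!probs.empty());
  const double n = static_cast<double>(probs.size());
  std::vector<double> weights(probs.size());
  for (std::size_t i = 0; i < probs.size(); ++i) {
    weights[i] = std::pow(n * probs[i], -beta);
  }
  const double max_w = *std::max_element(weights.begin(), weights.end());
  for (double& w : weights) w /= max_w;
  return weights;
}

// Importance-sampling-weighted MSE: mean(w * (q - y)^2).
double weighted_mse(const std::vector<double>& w, const std::vector<double>& q,
                    const std::vector<double>& y) {
  assert(!w.empty() && w.size() == q.size() && q.size() == y.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < w.size(); ++i) {
    const double td = q[i] - y[i];  // |td| is what would be fed back as the new priority
    sum += w[i] * td * td;
  }
  return sum / static_cast<double>(w.size());
}
```

With the documented defaults (``beta0 = 0.4``, ``beta_max = 1.0``, ``beta_steps = 100000``), this schedule gives ``beta = 0.7`` halfway through annealing. Transitions with a high priority are sampled more often but receive a smaller weight, which is how the sampling bias is corrected in the loss.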

For on-policy algorithms, the block looks as follows:

.. code-block:: yaml
12 changes: 5 additions & 7 deletions src/csrc/include/internal/rl/off_policy/ddpg.h
@@ -52,7 +52,8 @@ template <typename T>
void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const ModelPack& q_model,
const ModelPack& q_model_target, torch::Tensor state_old_tensor, torch::Tensor state_new_tensor,
torch::Tensor action_old_tensor, torch::Tensor action_new_tensor, torch::Tensor reward_tensor,
-torch::Tensor d_tensor, const T& gamma, const T& rho, T& p_loss_val, T& q_loss_val) {
+torch::Tensor d_tensor, torch::Tensor is_weights, const T& gamma, const T& rho,
+torch::Tensor& td_errors, T& p_loss_val, T& q_loss_val) {

// nvtx marker
torchfort::nvtx::rangePush("torchfort_train_ddpg");
@@ -72,10 +73,6 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const
// value functions
q_model.model->train();

-// opt
-// loss is fixed by algorithm
-auto q_loss_func = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
-
// policy function
// compute y: use the target models for q_new, no grads
torch::Tensor y_tensor;
@@ -87,10 +84,11 @@ void train_ddpg(const ModelPack& p_model, const ModelPack& p_model_target, const
}

// backward and update step
-// compute loss
+// IS-weighted MSE loss: mean(w * (q - y)^2)
torch::Tensor q_old_tensor =
torch::squeeze(q_model.model->forward(std::vector<torch::Tensor>{state_old_tensor, action_old_tensor})[0], 1);
-torch::Tensor q_loss_tensor = q_loss_func->forward(q_old_tensor, y_tensor);
+td_errors = torch::abs(q_old_tensor - y_tensor).detach();
+torch::Tensor q_loss_tensor = torch::mean(is_weights * torch::square(q_old_tensor - y_tensor));

auto state = q_model.state;
if (state->step_train_current % q_model.grad_accumulation_steps == 0) {
15 changes: 7 additions & 8 deletions src/csrc/include/internal/rl/off_policy/sac.h
@@ -56,10 +56,10 @@ template <typename T>
void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models,
const std::vector<ModelPack>& q_models_target, torch::Tensor state_old_tensor,
torch::Tensor state_new_tensor, torch::Tensor action_old_tensor, torch::Tensor reward_tensor,
-torch::Tensor d_tensor, const std::shared_ptr<AlphaModel>& alpha_model,
+torch::Tensor d_tensor, torch::Tensor is_weights, const std::shared_ptr<AlphaModel>& alpha_model,
const std::shared_ptr<torch::optim::Optimizer>& alpha_optimizer,
const std::shared_ptr<BaseLRScheduler>& alpha_lr_scheduler, const T& target_entropy, const T& gamma,
-const T& rho, T& p_loss_val, T& q_loss_val) {
+const T& rho, torch::Tensor& td_errors, T& p_loss_val, T& q_loss_val) {

// nvtx marker
torchfort::nvtx::rangePush("torchfort_train_sac");
@@ -84,10 +84,6 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
q_model_target.model->train();
}

-// opt
-// loss is fixed by algorithm
-auto q_loss_func = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
-
// if we are updating the entropy coefficient, do that first
torch::Tensor alpha_loss;
auto state = p_model.state;
@@ -168,9 +164,12 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
}

// backward and update step
+// IS-weighted MSE loss: mean(w * (q - y)^2), summed across critics
+// td_errors taken from first critic only
torch::Tensor q_old_tensor =
torch::squeeze(q_models[0].model->forward(std::vector<torch::Tensor>{state_old_tensor, action_old_tensor})[0], 1);
-torch::Tensor q_loss_tensor = q_loss_func->forward(q_old_tensor, y_tensor);
+td_errors = torch::abs(q_old_tensor - y_tensor).detach();
+torch::Tensor q_loss_tensor = torch::mean(is_weights * torch::square(q_old_tensor - y_tensor));
state = q_models[0].state;
if (state->step_train_current % q_models[0].grad_accumulation_steps == 0) {
q_models[0].optimizer->zero_grad();
@@ -179,7 +178,7 @@ void train_sac(const PolicyPack& p_model, const std::vector<ModelPack>& q_models
// compute loss
q_old_tensor = torch::squeeze(
q_models[i].model->forward(std::vector<torch::Tensor>{state_old_tensor, action_old_tensor})[0], 1);
-q_loss_tensor = q_loss_tensor + q_loss_func->forward(q_old_tensor, y_tensor);
+q_loss_tensor = q_loss_tensor + torch::mean(is_weights * torch::square(q_old_tensor - y_tensor));
state = q_models[i].state;
if (state->step_train_current % q_models[i].grad_accumulation_steps == 0) {
q_models[i].optimizer->zero_grad();
16 changes: 7 additions & 9 deletions src/csrc/include/internal/rl/off_policy/td3.h
@@ -52,8 +52,8 @@ template <typename T>
void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const std::vector<ModelPack>& q_models,
const std::vector<ModelPack>& q_models_target, torch::Tensor state_old_tensor,
torch::Tensor state_new_tensor, torch::Tensor action_old_tensor, torch::Tensor action_new_tensor,
-torch::Tensor reward_tensor, torch::Tensor d_tensor, const T& gamma, const T& rho, T& p_loss_val,
-T& q_loss_val, bool update_policy) {
+torch::Tensor reward_tensor, torch::Tensor d_tensor, torch::Tensor is_weights, const T& gamma,
+const T& rho, torch::Tensor& td_errors, T& p_loss_val, T& q_loss_val, bool update_policy) {

// nvtx marker
torchfort::nvtx::rangePush("torchfort_train_td3");
@@ -76,10 +76,6 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const
q_model.model->train();
}

-// opt
-// loss is fixed by algorithm
-auto q_loss_func = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
-
// policy function
// compute y: use the target models for q_new, no grads
torch::Tensor y_tensor;
@@ -96,18 +92,20 @@ void train_td3(const ModelPack& p_model, const ModelPack& p_model_target, const
}

// backward and update step
-// compute loss for critics and zero grads while we are at it
+// IS-weighted MSE loss: mean(w * (q - y)^2), summed across critics
+// td_errors taken from first critic only (consistent with policy update using q_models[0])
torch::Tensor q_old_tensor =
torch::squeeze(q_models[0].model->forward(std::vector<torch::Tensor>{state_old_tensor, action_old_tensor})[0], 1);
-torch::Tensor q_loss_tensor = q_loss_func->forward(q_old_tensor, y_tensor);
+td_errors = torch::abs(q_old_tensor - y_tensor).detach();
+torch::Tensor q_loss_tensor = torch::mean(is_weights * torch::square(q_old_tensor - y_tensor));
auto state = q_models[0].state;
if (state->step_train_current % q_models[0].grad_accumulation_steps == 0) {
q_models[0].optimizer->zero_grad();
}
for (int i = 1; i < q_models.size(); ++i) {
q_old_tensor = torch::squeeze(
q_models[i].model->forward(std::vector<torch::Tensor>{state_old_tensor, action_old_tensor})[0], 1);
-q_loss_tensor = q_loss_tensor + q_loss_func->forward(q_old_tensor, y_tensor);
+q_loss_tensor = q_loss_tensor + torch::mean(is_weights * torch::square(q_old_tensor - y_tensor));
state = q_models[i].state;
if (state->step_train_current % q_models[i].grad_accumulation_steps == 0) {
q_models[i].optimizer->zero_grad();