Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ target_sources(${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/lr_schedulers/step_lr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/lr_schedulers/linear_lr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/mlp_model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/sac_model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/actor_critic_model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/rl/actor_critic_model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/rl/common_models.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/rl/sac_model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/policy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/off_policy/interface.cpp
Expand Down
13 changes: 13 additions & 0 deletions src/csrc/include/internal/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ struct MLPModel : BaseModel, public std::enable_shared_from_this<BaseModel> {
std::vector<torch::Tensor> biases;
};

// MLP-based critic network for RL algorithms: consumes a (state, action)
// pair and produces a value estimate. Registered in the model registry
// under the name "CriticMLP".
struct CriticMLPModel : BaseModel, public std::enable_shared_from_this<BaseModel> {
// Builds the fully-connected stack from the "dropout" and "layer_sizes" params.
void setup(const ParamMap& params) override;
// Expects exactly two input tensors (state, action); returns one output tensor.
std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override;

// dropout probability applied after each hidden layer during training
double dropout;
// sizes of all layers, input through output
std::vector<int> layer_sizes;

// One Linear submodule per layer transition, plus an extra learnable
// bias tensor for each hidden layer (none for the output layer).
std::vector<torch::nn::Linear> fc_layers;
std::vector<torch::Tensor> biases;
};

struct SACMLPModel : BaseModel, public std::enable_shared_from_this<BaseModel> {
void setup(const ParamMap& params) override;
std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override;
Expand Down Expand Up @@ -86,6 +98,7 @@ BEGIN_MODEL_REGISTRY

// Add entries for new models in this section.
REGISTER_MODEL(MLP, MLPModel)
REGISTER_MODEL(CriticMLP, CriticMLPModel)
REGISTER_MODEL(SACMLP, SACMLPModel)
REGISTER_MODEL(ActorCriticMLP, ActorCriticMLPModel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,24 @@ void ActorCriticMLPModel::setup(const ParamMap& params) {

// Implement the forward function.
std::vector<torch::Tensor> ActorCriticMLPModel::forward(const std::vector<torch::Tensor>& inputs) {
// concatenate inputs
auto x = torch::cat(inputs, 1);
x = x.reshape({x.size(0), -1});

// make sure only one tensor (state) is fed
if (inputs.size() != 1) {
THROW_INVALID_USAGE("You have to provide exactly one tensor (state) to the ActorCriticMLPModel");
}

// unpack
auto state = inputs[0];

// expand dims if necessary
if (state.dim() == 1) {
state = state.unsqueeze(0);
}

// flatten everything beyond dim 0:
auto x = state.reshape({state.size(0), -1});

// forward pass
for (int i = 0; i < encoder_layer_sizes.size() - 1; ++i) {
// encoder part
x = torch::relu(encoder_layers[i]->forward(x) + encoder_biases[i]);
Expand Down
99 changes: 99 additions & 0 deletions src/csrc/models/rl/common_models.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#include <vector>

#include <torch/torch.h>

#include "internal/models.h"
#include "internal/param_map.h"
#include "internal/setup.h"

namespace torchfort {

// MLP model in C++ using libtorch
void CriticMLPModel::setup(const ParamMap& params) {
// Extract params from input map.
std::set<std::string> supported_params{"dropout", "layer_sizes"};
check_params(supported_params, params.keys());

dropout = params.get_param<double>("dropout", 0.0)[0];
layer_sizes = params.get_param<int>("layer_sizes");

// Construct and register submodules.
for (int i = 0; i < layer_sizes.size() - 1; ++i) {
fc_layers.push_back(
register_module("fc" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1])));
if (i < layer_sizes.size() - 2) {
biases.push_back(register_parameter("b" + std::to_string(i), torch::zeros(layer_sizes[i + 1])));
}
}
}

// Implement the forward function.
// Forward pass of the critic: maps a (state, action) pair to a value estimate.
//
// Expects exactly two input tensors in the order (state, action). A 1D input
// is treated as a single unbatched sample and gets a batch dimension added;
// everything beyond the batch dimension is flattened before the MLP.
// Returns a vector containing the single output tensor.
std::vector<torch::Tensor> CriticMLPModel::forward(const std::vector<torch::Tensor>& inputs) {

  // make sure that exactly two tensors are fed, state and action:
  if (inputs.size() != 2) {
    THROW_INVALID_USAGE("You have to provide exactly two tensors (state, action) to the CriticMLPModel");
  }

  // unpack
  auto state = inputs[0];
  auto action = inputs[1];

  // expand dims if necessary (unbatched 1D sample -> batch of one)
  if (state.dim() == 1) {
    state = state.unsqueeze(0);
  }
  if (action.dim() == 1) {
    action = action.unsqueeze(0);
  }

  // flatten everything beyond dim 0:
  state = state.reshape({state.size(0), -1});
  action = action.reshape({action.size(0), -1});

  // concatenate inputs along feature dimension
  auto x = torch::cat({state, action}, 1);

  // forward pass: relu + learnable bias + dropout on hidden layers,
  // plain affine transform on the output layer. Signed loop count avoids
  // signed/unsigned comparison against layer_sizes.size().
  const int num_layers = static_cast<int>(layer_sizes.size());
  for (int i = 0; i < num_layers - 1; ++i) {
    if (i < num_layers - 2) {
      x = torch::relu(fc_layers[i]->forward(x) + biases[i]);
      x = torch::dropout(x, dropout, is_training());
    } else {
      x = fc_layers[i]->forward(x);
    }
  }
  return std::vector<torch::Tensor>{x};
}

} // namespace torchfort
19 changes: 16 additions & 3 deletions src/csrc/models/sac_model.cpp → src/csrc/models/rl/sac_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,22 @@ void SACMLPModel::setup(const ParamMap& params) {

// Implement the forward function.
std::vector<torch::Tensor> SACMLPModel::forward(const std::vector<torch::Tensor>& inputs) {
// concatenate inputs
auto x = torch::cat(inputs, 1);
x = x.reshape({x.size(0), -1});

// make sure that exactly one tensor (state) is fed
if (inputs.size() != 1) {
THROW_INVALID_USAGE("You have to provide exactly one tensor (state) to the SACMLPModel");
}

// unpack
auto state = inputs[0];

// expand dims if necessary
if (state.dim() == 1) {
state = state.unsqueeze(0);
}

// flatten everything beyond dim 0:
auto x = state.reshape({state.size(0), -1});
torch::Tensor y, z;

for (int i = 0; i < layer_sizes.size() - 1; ++i) {
Expand Down
2 changes: 1 addition & 1 deletion tests/rl/configs/ddpg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ policy_model:
layer_sizes: [1, 16, 1]

critic_model:
type: MLP
type: CriticMLP
parameters:
dropout: 0.0
layer_sizes: [2, 16, 1]
Expand Down
2 changes: 1 addition & 1 deletion tests/rl/configs/sac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ policy_model:
log_sigma_init: 0.

critic_model:
type: MLP
type: CriticMLP
parameters:
dropout: 0.0
layer_sizes: [2, 16, 1]
Expand Down
2 changes: 1 addition & 1 deletion tests/rl/configs/td3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ policy_model:
layer_sizes: [1, 16, 1]

critic_model:
type: MLP
type: CriticMLP
parameters:
dropout: 0.0
layer_sizes: [2, 16, 1]
Expand Down