Skip to content
This repository has been archived by the owner on Dec 28, 2023. It is now read-only.

Commit

Permalink
Replace spdlog output with exception messages (#11)
Browse files Browse the repository at this point in the history
* Replace spdlog output with exception messages

* Move spdlog into example directory

* Allow disabling of example and test executables
  • Loading branch information
Isaac Poulton committed Apr 24, 2019
1 parent 6460f0a commit 485e5b9
Show file tree
Hide file tree
Showing 12 changed files with 40 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,5 +1,5 @@
[submodule "lib/spdlog"]
path = lib/spdlog
path = example/lib/spdlog
url = git@github.com:gabime/spdlog.git
[submodule "lib/msgpack-c"]
path = example/lib/msgpack-c
Expand Down
18 changes: 11 additions & 7 deletions CMakeLists.txt
Expand Up @@ -24,7 +24,6 @@ list(APPEND CPPCHECK_ARGS
--suppressions-list=${CMAKE_CURRENT_LIST_DIR}/CppCheckSuppressions.txt
-I ${CMAKE_CURRENT_LIST_DIR}/src
-I ${CMAKE_CURRENT_LIST_DIR}/include
-I ${CMAKE_CURRENT_LIST_DIR}/lib/spdlog/include
-I ${CMAKE_CURRENT_LIST_DIR}/example
${CMAKE_CURRENT_LIST_DIR}/src
${CMAKE_CURRENT_LIST_DIR}/example
Expand All @@ -42,8 +41,6 @@ find_package(Torch REQUIRED)
if (TORCH_CXX_FLAGS)
set(CMAKE_CXX_FLAGS ${TORCH_CXX_FLAGS})
endif()
## Spdlog
add_subdirectory(lib/spdlog)

# Define targets
add_library(cpprl STATIC "")
Expand All @@ -64,18 +61,25 @@ endif(MSVC)
set(CPPRL_INCLUDE_DIRS
include
src
lib/spdlog/include
${TORCH_INCLUDE_DIRS}
)
target_include_directories(cpprl PRIVATE ${CPPRL_INCLUDE_DIRS})
target_include_directories(cpprl_tests PRIVATE ${CPPRL_INCLUDE_DIRS})
if (CPPRL_BUILD_TESTS)
target_include_directories(cpprl_tests PRIVATE ${CPPRL_INCLUDE_DIRS})
endif(CPPRL_BUILD_TESTS)

# Linking
target_link_libraries(cpprl torch ${TORCH_LIBRARIES})
target_link_libraries(cpprl_tests torch ${TORCH_LIBRARIES})
target_link_libraries(cpprl torch ${TORCH_LIBRARIES})
if (CPPRL_BUILD_TESTS)
target_link_libraries(cpprl_tests torch ${TORCH_LIBRARIES})
endif(CPPRL_BUILD_TESTS)

# Example
add_subdirectory(example)
option(CPPRL_BUILD_EXAMPLE "Whether or not to build the CppRl Gym example" ON)
if (CPPRL_BUILD_EXAMPLE)
add_subdirectory(example)
endif(CPPRL_BUILD_EXAMPLE)

# Recurse into source tree
add_subdirectory(src)
4 changes: 4 additions & 0 deletions example/CMakeLists.txt
Expand Up @@ -3,8 +3,11 @@ add_executable(gym_client gym_client.cpp communicator.cpp)
set(LIB_DIR ${CMAKE_CURRENT_LIST_DIR}/lib)
set(CPPZMQ_DIR ${LIB_DIR}/cppzmq)
set(MSGPACK_DIR ${LIB_DIR}/msgpack-c)
set(SPDLOG_DIR ${LIB_DIR}/spdlog)
set(ZMQ_DIR ${LIB_DIR}/libzmq)

# Spdlog
add_subdirectory(${SPDLOG_DIR})
# ZMQ
option(ZMQ_BUILD_TESTS "" OFF)
add_subdirectory(${ZMQ_DIR})
Expand All @@ -16,6 +19,7 @@ target_include_directories(gym_client
../lib/spdlog/include
${CPPZMQ_DIR}
${MSGPACK_DIR}/include
${SPDLOG_DIR}/include
${ZMQ_DIR}/include
)

Expand Down
1 change: 0 additions & 1 deletion src/algorithms/ppo.cpp
@@ -1,7 +1,6 @@
#include <chrono>
#include <memory>

#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/algorithms/ppo.h"
Expand Down
8 changes: 3 additions & 5 deletions src/distributions/bernoulli.cpp
@@ -1,6 +1,5 @@
#include <ATen/core/Reduction.h>
#include <c10/util/ArrayRef.h>
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/distributions/bernoulli.h"
Expand All @@ -13,15 +12,14 @@ Bernoulli::Bernoulli(const torch::Tensor *probs,
{
if ((probs == nullptr) == (logits == nullptr))
{
spdlog::error("Either probs or logits is required, but not both");
throw std::exception();
throw std::runtime_error("Either probs or logits is required, but not both");
}

if (probs != nullptr)
{
if (probs->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Probabilities tensor must have at least one dimension");
}
this->probs = *probs;
// 1.21e-7 is used as the epsilon to match PyTorch's Python results as closely
Expand All @@ -33,7 +31,7 @@ Bernoulli::Bernoulli(const torch::Tensor *probs,
{
if (logits->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Logits tensor must have at least one dimension");
}
this->logits = *logits;
this->probs = torch::sigmoid(*logits);
Expand Down
8 changes: 3 additions & 5 deletions src/distributions/categorical.cpp
@@ -1,5 +1,4 @@
#include <c10/util/ArrayRef.h>
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/distributions/categorical.h"
Expand All @@ -12,15 +11,14 @@ Categorical::Categorical(const torch::Tensor *probs,
{
if ((probs == nullptr) == (logits == nullptr))
{
spdlog::error("Either probs or logits is required, but not both");
throw std::exception();
throw std::runtime_error("Either probs or logits is required, but not both");
}

if (probs != nullptr)
{
if (probs->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Probabilities tensor must have at least one dimension");
}
this->probs = *probs / probs->sum(-1, true);
// 1.21e-7 is used as the epsilon to match PyTorch's Python results as closely
Expand All @@ -32,7 +30,7 @@ Categorical::Categorical(const torch::Tensor *probs,
{
if (logits->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Logits tensor must have at least one dimension");
}
this->logits = *logits - logits->logsumexp(-1, true);
this->probs = torch::softmax(this->logits, -1);
Expand Down
6 changes: 1 addition & 5 deletions src/generators/feed_forward_generator.cpp
@@ -1,7 +1,6 @@
#include <algorithm>
#include <vector>

#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/generators/feed_forward_generator.h"
Expand Down Expand Up @@ -44,10 +43,7 @@ MiniBatch FeedForwardGenerator::next()
{
if (index >= indices.size(0))
{
spdlog::error("No minibatches left in generator. Index {}, minibatch "
"count: {}.",
index, indices.size(0));
throw std::exception();
throw std::runtime_error("No minibatches left in generator.");
}

MiniBatch mini_batch;
Expand Down
6 changes: 1 addition & 5 deletions src/generators/recurrent_generator.cpp
@@ -1,7 +1,6 @@
#include <algorithm>
#include <vector>

#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/generators/recurrent_generator.h"
Expand Down Expand Up @@ -49,10 +48,7 @@ MiniBatch RecurrentGenerator::next()
{
if (index >= indices.size(0))
{
spdlog::error("No minibatches left in generator. Index {}, minibatch "
"count: {}.",
index, indices.size(0));
throw std::exception();
throw std::runtime_error("No minibatches left in generator.");
}

MiniBatch mini_batch;
Expand Down
4 changes: 1 addition & 3 deletions src/model/policy.cpp
@@ -1,4 +1,3 @@
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/model/policy.h"
Expand Down Expand Up @@ -35,8 +34,7 @@ PolicyImpl::PolicyImpl(ActionSpace action_space, std::shared_ptr<NNBase> base)
}
else
{
spdlog::error("Action space {} not supported", action_space.type);
throw std::exception();
throw std::runtime_error("Action space " + action_space.type + " not supported");
}
register_module("output", output_layer);
}
Expand Down
26 changes: 13 additions & 13 deletions src/storage.cpp
Expand Up @@ -2,7 +2,6 @@
#include <vector>

#include <c10/util/ArrayRef.h>
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/generators/feed_forward_generator.h"
Expand Down Expand Up @@ -97,13 +96,14 @@ std::unique_ptr<Generator> RolloutStorage::feed_forward_generator(
auto batch_size = num_processes * num_steps;
if (batch_size < num_mini_batch)
{
spdlog::error("PPO needs the number of processes ({}) * the number of "
"steps ({}) = {} to be greater than or equal to the number "
"of minibatches ({})",
num_processes,
num_steps,
num_mini_batch);
throw std::exception();
throw std::runtime_error("PPO needs the number of processes (" +
std::to_string(num_processes) +
") * the number of steps (" +
std::to_string(num_steps) + ") = " +
std::to_string(num_processes * num_steps) +
" to be greater than or equal to the number of minibatches (" +
std::to_string(num_mini_batch) +
")");
}
auto mini_batch_size = batch_size / num_mini_batch;
return std::make_unique<FeedForwardGenerator>(
Expand Down Expand Up @@ -143,11 +143,11 @@ std::unique_ptr<Generator> RolloutStorage::recurrent_generator(
auto num_processes = actions.size(1);
if (num_processes < num_mini_batch)
{
spdlog::error("PPO needs the number of processes ({}) to be greater than or"
" equal to the number of minibatches ({})",
num_processes,
num_mini_batch);
throw std::exception();
throw std::runtime_error("PPO needs the number of processes (" +
std::to_string(num_processes) +
") to be greater than or equal to the number of minibatches (" +
std::to_string(num_mini_batch) +
")");
}
return std::make_unique<RecurrentGenerator>(
num_processes,
Expand Down
12 changes: 2 additions & 10 deletions src/third_party/doctest.cpp
@@ -1,11 +1,3 @@
#define DOCTEST_CONFIG_IMPLEMENT
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN

#include <spdlog/spdlog.h>

#include "third_party/doctest.h"

int main(int argc, char **argv)
{
spdlog::set_level(spdlog::level::off);
return doctest::Context(argc, argv).run();
}
#include "third_party/doctest.h"

0 comments on commit 485e5b9

Please sign in to comment.