- sphinx https://shigekikarita.github.io/thxx/sphinx/html
- doxygen https://shigekikarita.github.io/thxx/
https://travis-ci.org/ShigekiKarita/thxx
compiler | conda package | latest zip |
---|---|---|
gcc-7 | ||
gcc-8 | ||
clang-5 | ||
clang-6 | ||
clang-7 |
- gcc 7, 8 or clang 5, 6, 7
- compiler option
-std=c++17
- libtorch 1.0.0 (recommended: install via `conda install pytorch-cpu=1.0.0 -c pytorch`)
- using conda (recommended)
$ conda install pytorch-cpu=1.0.0 -c pytorch
$ make test
$ make example-mnist
- using libtorch latest zip
$ cd <path/to/thxx>   # the thxx repository root
$ make --directory ./third_party libtorch-shared-with-deps-latest.zip
$ make test
$ make example-mnist
For more details, see `.travis.yml`.
- thxx::meta::Lambda<...> : torch::nn::ModuleHolder<...>
auto x = torch::rand({2, 3, 4});
// single input/output
Lambda f1 = lambda([](auto&& x) { return x * 2; });
auto x1 = f1->forward(x);
CHECK_THAT( x1, testing::TensorEq(x * 2) );
// multi input/output
Lambda f2 = lambda([](auto&& x1, auto&& x2) { return std::make_tuple(x1.relu(), x2 * 2); });
{
auto [a, b] = f2->forward(x, x);
CHECK_THAT( a, testing::TensorEq(torch::relu(x)) );
CHECK_THAT( b, testing::TensorEq(x * 2) );
}
{
// automatic tuple arg unpack
auto [a, b] = f2->forward(std::make_tuple(x, x));
CHECK_THAT( a, testing::TensorEq(torch::relu(x)) );
CHECK_THAT( b, testing::TensorEq(x * 2) );
}
- thxx::meta::Seq<...> : torch::nn::ModuleHolder<...>
auto x1 = torch::rand({2, 3, 4});
// single input/output with Lambda
auto f1 = torch::nn::Linear(4, 5);
auto f2 = sequential(f1, lambda(torch::relu));
auto x2 = f2->forward(x1);
CHECK_THAT( x2, testing::TensorEq(torch::relu(f1->forward(x1))) );
// share submodules/parameters
CHECK_THAT( f2->parameters()[0], testing::TensorEq(f1->weight) );
CHECK_THAT( f2->parameters()[1], testing::TensorEq(f1->bias) );
// also support multi input/output with Lambda
auto f3 = lambda([](auto&& x1, auto&& x2) { return std::make_tuple(x1, x2 * 2); });
auto f4 = sequential(f3, f3);
auto [x3, x4] = f4->forward(x1, x1);
CHECK_THAT( x3, testing::TensorEq(x1) );
CHECK_THAT( x4, testing::TensorEq(x1 * 4) );
- Transformer
- MultiHeadedAttention
- PositionalEncoding
- PositionwiseFeedforward
- pad/masking functions
- Normalization
- LayerNorm
- Math (wip)
- copy the C++ batched/complex linalg functions from https://github.com/ShigekiKarita/thxx-py
- Loss
- label smoothing KLDivLoss
- HDF5 (wip)
- numpy (wip)
The MNIST example is forked from https://github.com/goldsborough/examples/tree/cpp/cpp/mnist
BSL-1.0