Skip to content

Commit

Permalink
James.yang/concepts cpp20 (#34)
Browse files Browse the repository at this point in the history
* Extended to c++20 concepts

* Finished full changes to all concepts and add travis CI

* Try xenial instead of trusty

* Fix mock dist expr with no pdf

* Change to bionic

* Set g++-10 build to use c++20 features

* Change setup to use safer cd

* Try &> method instead
  • Loading branch information
JamesYang007 committed Jun 3, 2020
1 parent e75fa3e commit b2a913f
Show file tree
Hide file tree
Showing 30 changed files with 480 additions and 28 deletions.
15 changes: 15 additions & 0 deletions .travis.yml
Expand Up @@ -100,6 +100,21 @@ jobs:
'libopenblas-dev', 'liblapack-dev',
'libarpack2-dev']

- stage: test
dist: bionic
os: linux
compiler: gcc
env:
- CXX_COMPILER="g++-10"
- CC_COMPILER="gcc-10"
- CMAKE_OPTIONS="-DCMAKE_CXX_FLAGS=-std=c++20"
addons:
apt:
sources: ['ubuntu-toolchain-r-test']
packages: ['g++-10', 'ninja-build',
'libopenblas-dev', 'liblapack-dev',
'libarpack2-dev']

- stage: deploy
dist: bionic
os: linux
Expand Down
113 changes: 111 additions & 2 deletions include/autoppl/expr_builder.hpp
Expand Up @@ -31,6 +31,9 @@ namespace details {
* - is_var_expr_v<T> true => T
* Assumes each condition is non-overlapping.
*/

#if __cplusplus <= 201703L

template <class T, class = void>
struct convert_to_param
{};
Expand Down Expand Up @@ -65,10 +68,40 @@ struct convert_to_param<T,
using type = T;
};

#else

template <class T>
struct convert_to_param;

template <class T>
requires util::var<std::decay_t<T>>
struct convert_to_param<T>
{
using type = expr::VariableViewer<std::decay_t<T>>;
};

template <class T>
requires std::is_arithmetic_v<std::decay_t<T>>
struct convert_to_param<T>
{
using type = expr::Constant<std::decay_t<T>>;
};

template <class T>
requires util::var_expr<std::decay_t<T>>
struct convert_to_param<T>
{
using type = T;
};

#endif

template <class T>
using convert_to_param_t =
typename convert_to_param<T>::type;

#if __cplusplus <= 201703L

/**
* Checks if valid distribution parameter:
* - can be arithmetic
Expand All @@ -94,20 +127,46 @@ inline constexpr bool is_not_both_arithmetic_v =
std::is_arithmetic_v<std::decay_t<T2>>)
;

#else

template <class T>
concept valid_dist_param =
std::is_arithmetic_v<std::decay_t<T>> ||
(util::var<std::decay_t<T>> &&
!std::is_rvalue_reference_v<T> &&
!std::is_const_v<std::remove_reference_t<T>>) ||
(util::var_expr<std::decay_t<T>>)
;

template <class T1, class T2>
concept not_both_arithmetic =
!(std::is_arithmetic_v<std::decay_t<T1>> &&
std::is_arithmetic_v<std::decay_t<T2>>)
;

#endif

} // namespace details

/**
* Builds a Uniform expression only when the parameters
* are both valid continuous distribution parameter types.
* See var_expr.hpp for more information.
*/
#if __cplusplus <= 201703L
template <class MinType, class MaxType
, class = std::enable_if_t<
details::is_valid_dist_param_v<MinType> &&
details::is_valid_dist_param_v<MaxType>
> >
inline constexpr auto uniform(MinType&& min_expr,
MaxType&& max_expr)
#else
template <details::valid_dist_param MinType
, details::valid_dist_param MaxType>
inline constexpr auto uniform(MinType&& min_expr,
MaxType&& max_expr)
#endif
{
using min_t = details::convert_to_param_t<MinType>;
using max_t = details::convert_to_param_t<MaxType>;
Expand All @@ -123,13 +182,20 @@ inline constexpr auto uniform(MinType&& min_expr,
* are both valid continuous distribution parameter types.
* See var_expr.hpp for more information.
*/
#if __cplusplus <= 201703L
template <class MeanType, class StddevType
, class = std::enable_if_t<
details::is_valid_dist_param_v<MeanType> &&
details::is_valid_dist_param_v<StddevType>
> >
inline constexpr auto normal(MeanType&& mean_expr,
StddevType&& stddev_expr)
#else
template <details::valid_dist_param MeanType
, details::valid_dist_param StddevType>
inline constexpr auto normal(MeanType&& mean_expr,
StddevType&& stddev_expr)
#endif
{
using mean_t = details::convert_to_param_t<MeanType>;
using stddev_t = details::convert_to_param_t<StddevType>;
Expand All @@ -145,11 +211,16 @@ inline constexpr auto normal(MeanType&& mean_expr,
* is a valid discrete distribution parameter type.
* See var_expr.hpp for more information.
*/
#if __cplusplus <= 201703L
template <class ProbType
, class = std::enable_if_t<
details::is_valid_dist_param_v<ProbType>
> >
inline constexpr auto bernoulli(ProbType&& p_expr)
#else
template <details::valid_dist_param ProbType>
inline constexpr auto bernoulli(ProbType&& p_expr)
#endif
{
using p_t = details::convert_to_param_t<ProbType>;
p_t wrap_p_expr = std::forward<ProbType>(p_expr);
Expand All @@ -166,8 +237,10 @@ inline constexpr auto bernoulli(ProbType&& p_expr)
* Ex. x |= uniform(0,1)
*/
template <class VarType, class DistType>
inline constexpr auto operator|=(util::Var<VarType>& var,
const util::DistExpr<DistType>& dist) { return expr::EqNode(var.self(), dist.self()); }
inline constexpr auto operator|=(
util::Var<VarType>& var,
const util::DistExpr<DistType>& dist)
{ return expr::EqNode(var.self(), dist.self()); }

/**
* Builds a GlueNode to "glue" the left expression with the right
Expand All @@ -185,6 +258,8 @@ inline constexpr auto operator,(const util::ModelExpr<LHSNodeType>& lhs,

namespace details {

#if __cplusplus <= 201703L

/**
* Concept of valid variable expression parameter
* for the operator overloads:
Expand All @@ -198,6 +273,16 @@ inline constexpr bool is_valid_op_param_v =
util::is_var_expr_v<std::decay_t<T>> ||
util::is_var_v<std::decay_t<T>>
;
#else

template <class T>
concept valid_op_param =
std::is_arithmetic_v<std::decay_t<T>> ||
util::var_expr<std::decay_t<T>> ||
util::var<std::decay_t<T>>
;

#endif

template <class Op, class LHSType, class RHSType>
inline constexpr auto operator_helper(LHSType&& lhs, RHSType&& rhs)
Expand All @@ -224,12 +309,18 @@ inline constexpr auto operator_helper(LHSType&& lhs, RHSType&& rhs)
* SFINAE to ensure concepts are placed.
*/

#if __cplusplus <= 201703L
template <class LHSType, class RHSType
, class = std::enable_if_t<
details::is_not_both_arithmetic_v<LHSType, RHSType> &&
details::is_valid_op_param_v<LHSType> &&
details::is_valid_op_param_v<RHSType>
> >
#else
template <details::valid_op_param LHSType
, details::valid_op_param RHSType>
requires details::not_both_arithmetic<LHSType, RHSType>
#endif
inline constexpr auto operator+(LHSType&& lhs,
RHSType&& rhs)
{
Expand All @@ -238,38 +329,56 @@ inline constexpr auto operator+(LHSType&& lhs,
std::forward<RHSType>(rhs));
}

#if __cplusplus <= 201703L
template <class LHSType, class RHSType
, class = std::enable_if_t<
details::is_not_both_arithmetic_v<LHSType, RHSType> &&
details::is_valid_op_param_v<LHSType> &&
details::is_valid_op_param_v<RHSType>
> >
#else
template <details::valid_op_param LHSType
, details::valid_op_param RHSType>
requires details::not_both_arithmetic<LHSType, RHSType>
#endif
inline constexpr auto operator-(LHSType&& lhs, RHSType&& rhs)
{
return details::operator_helper<expr::SubOp>(
std::forward<LHSType>(lhs),
std::forward<RHSType>(rhs));
}

#if __cplusplus <= 201703L
template <class LHSType, class RHSType
, class = std::enable_if_t<
details::is_not_both_arithmetic_v<LHSType, RHSType> &&
details::is_valid_op_param_v<LHSType> &&
details::is_valid_op_param_v<RHSType>
> >
#else
template <details::valid_op_param LHSType
, details::valid_op_param RHSType>
requires details::not_both_arithmetic<LHSType, RHSType>
#endif
inline constexpr auto operator*(LHSType&& lhs, RHSType&& rhs)
{
return details::operator_helper<expr::MultOp>(
std::forward<LHSType>(lhs),
std::forward<RHSType>(rhs));
}

#if __cplusplus <= 201703L
template <class LHSType, class RHSType
, class = std::enable_if_t<
details::is_not_both_arithmetic_v<LHSType, RHSType> &&
details::is_valid_op_param_v<LHSType> &&
details::is_valid_op_param_v<RHSType>
> >
#else
template <details::valid_op_param LHSType
, details::valid_op_param RHSType>
requires details::not_both_arithmetic<LHSType, RHSType>
#endif
inline constexpr auto operator/(LHSType&& lhs, RHSType&& rhs)
{
return details::operator_helper<expr::DivOp>(
Expand Down
19 changes: 18 additions & 1 deletion include/autoppl/expression/distribution/bernoulli.hpp
Expand Up @@ -7,10 +7,17 @@
namespace ppl {
namespace expr {

#if __cplusplus <= 201703L
template <typename p_type>
#else
template <util::var_expr p_type>
#endif
struct Bernoulli : util::DistExpr<Bernoulli<p_type>>
{

#if __cplusplus <= 201703L
static_assert(util::assert_is_var_expr_v<p_type>);
#endif

using value_t = util::disc_param_t;
using param_value_t = typename util::var_expr_traits<p_type>::value_t;
Expand Down Expand Up @@ -43,7 +50,17 @@ struct Bernoulli : util::DistExpr<Bernoulli<p_type>>
else return std::numeric_limits<dist_value_t>::lowest();
}

param_value_t p(size_t index=0) const { return std::max(std::min(p_.get_value(index), static_cast<param_value_t>(max())), static_cast<param_value_t>(min())); }
param_value_t p(size_t index=0) const
{
return std::max(
std::min(
p_.get_value(index),
static_cast<param_value_t>(max())
),
static_cast<param_value_t>(min())
);
}

value_t min() const { return 0; }
value_t max() const { return 1; }

Expand Down
7 changes: 7 additions & 0 deletions include/autoppl/expression/distribution/normal.hpp
Expand Up @@ -14,11 +14,18 @@
namespace ppl {
namespace expr {

#if __cplusplus <= 201703L
template <typename mean_type, typename stddev_type>
#else
template <util::var_expr mean_type, util::var_expr stddev_type>
#endif
struct Normal : util::DistExpr<Normal<mean_type, stddev_type>>
{

#if __cplusplus <= 201703L
static_assert(util::assert_is_var_expr_v<mean_type>);
static_assert(util::assert_is_var_expr_v<stddev_type>);
#endif

using value_t = util::cont_param_t;
using base_t = util::DistExpr<Normal<mean_type, stddev_type>>;
Expand Down
7 changes: 7 additions & 0 deletions include/autoppl/expression/distribution/uniform.hpp
Expand Up @@ -8,11 +8,18 @@
namespace ppl {
namespace expr {

#if __cplusplus <= 201703L
template <typename min_type, typename max_type>
#else
template <util::var_expr min_type, util::var_expr max_type>
#endif
struct Uniform : util::DistExpr<Uniform<min_type, max_type>>
{

#if __cplusplus <= 201703L
static_assert(util::assert_is_var_expr_v<min_type>);
static_assert(util::assert_is_var_expr_v<max_type>);
#endif

using value_t = util::cont_param_t;
using base_t = util::DistExpr<Uniform<min_type, max_type>>;
Expand Down
15 changes: 15 additions & 0 deletions include/autoppl/expression/model/eq_node.hpp
Expand Up @@ -14,11 +14,18 @@ namespace expr {
* This class represents a "node" in the model expression
* that relates a var with a distribution.
*/
#if __cplusplus <= 201703L
template <class VarType, class DistType>
#else
template <util::var VarType, util::dist_expr DistType>
#endif
struct EqNode : util::ModelExpr<EqNode<VarType, DistType>>
{

#if __cplusplus <= 201703L
static_assert(util::assert_is_var_v<VarType>);
static_assert(util::assert_is_dist_expr_v<DistType>);
#endif

using var_t = VarType;
using dist_t = DistType;
Expand Down Expand Up @@ -71,7 +78,11 @@ struct EqNode : util::ModelExpr<EqNode<VarType, DistType>>
{
// if parameter, find the corresponding variable
// in vars and return the AD log-pdf with this variable.
#if __cplusplus <= 201703L
if constexpr (util::is_param_v<var_t>) {
#else
if constexpr (util::param<var_t>) {
#endif
const void* addr = &orig_var_ref_.get();
auto it = std::find(keys.begin(), keys.end(), addr);
assert(it != keys.end());
Expand All @@ -82,7 +93,11 @@ struct EqNode : util::ModelExpr<EqNode<VarType, DistType>>
// if data, return sum of log_pdf where each element
// is a constant AD node containing each value of data.
// note: data is not copied at any point.
#if __cplusplus <= 201703L
else if constexpr (util::is_data_v<var_t>) {
#else
else if constexpr (util::data<var_t>) {
#endif
const auto& var = this->get_variable();
size_t idx = 0;
const size_t size = var.size();
Expand Down

0 comments on commit b2a913f

Please sign in to comment.