From 9d626124e12a6172faa0f3c70ee61350268489e9 Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 18:47:53 -0400 Subject: [PATCH] added initial support for AddNode and MultNodes --- include/autoppl/expression/operator.hpp | 96 +++++++++++++++++++++++++ test/CMakeLists.txt | 1 + test/expression/operator_unittest.cpp | 31 ++++++++ 3 files changed, 128 insertions(+) create mode 100644 include/autoppl/expression/operator.hpp create mode 100644 test/expression/operator_unittest.cpp diff --git a/include/autoppl/expression/operator.hpp b/include/autoppl/expression/operator.hpp new file mode 100644 index 00000000..f50dbd94 --- /dev/null +++ b/include/autoppl/expression/operator.hpp @@ -0,0 +1,96 @@ +#pragma once + + +#include + +namespace ppl { + +/* + * Node resulting from adding two Variables together, or a variable + * and an element of type Variable::value_t. Can be casted to value_t + * and returns the sum of the left and right subtrees. +*/ +template +struct AddNode +{ + using value_t = ValueType; + AddNode(LeftChildType left, RightChildType right) : left_(left), right_(right) { + static_assert(std::is_convertible_v); + static_assert(std::is_convertible_v); + } + + operator value_t () const { return left_value() + right_value(); } + + value_t left_value() const { return static_cast(left_); } + value_t right_value() const { return static_cast(right_); } + +private: + LeftChildType left_; + RightChildType right_; +}; + +////////////////////////////////// +// Operator overloads for AddNode +/////////////////////////////////// + +template +AddNode, ValueType> operator+(ValueType left, Variable right) { + return AddNode, ValueType>(left, right); +} + +template +AddNode, ValueType, ValueType> operator+(Variable left, ValueType right) { + return AddNode, ValueType, ValueType>(left, right); +} + +template +AddNode, Variable, ValueType> operator+(Variable left, Variable right) { + return AddNode, Variable, ValueType>(left, right); +} + + +/* + * Node resulting from multiplying two Variables together, or a variable + * and an element of type Variable::value_t. Can be casted to value_t + * and returns the product of the left and right subtrees. +*/ +template +struct MultNode +{ + using value_t = ValueType; + MultNode(LeftChildType left, RightChildType right) : left_(left), right_(right) { + static_assert(std::is_convertible_v); + static_assert(std::is_convertible_v); + } + + operator value_t () const { return left_value() * right_value(); } + + value_t left_value() const { return static_cast(left_); } + value_t right_value() const { return static_cast(right_); } + +private: + LeftChildType left_; + RightChildType right_; +}; + +////////////////////////////////// +// Operator overloads for MultNode +/////////////////////////////////// + +template +MultNode, ValueType> operator*(ValueType left, Variable right) { + return MultNode, ValueType>(left, right); +} + +template +MultNode, ValueType, ValueType> operator*(Variable left, ValueType right) { + return MultNode, ValueType, ValueType>(left, right); +} + +template +MultNode, Variable, ValueType> operator*(Variable left, Variable right) { + return MultNode, Variable, ValueType>(left, right); +} + + +} // ppl \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 670e3d9a..95071834 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,6 +11,7 @@ endif() add_executable(expression_unittest ${CMAKE_CURRENT_SOURCE_DIR}/expression/model_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/operator_unittest.cpp ) target_compile_options(expression_unittest PRIVATE -g -Wall -Werror -Wextra) target_include_directories(expression_unittest PRIVATE ${GTEST_DIR}/include) diff --git a/test/expression/operator_unittest.cpp b/test/expression/operator_unittest.cpp new file mode 100644 index 00000000..65619f6f --- /dev/null +++ b/test/expression/operator_unittest.cpp @@ -0,0 +1,31 @@ +#include +#include +#include "gtest/gtest.h" +#include +#include + +namespace ppl { + +////////////////////////////////////////////////////// +// Model with one RV TESTS +////////////////////////////////////////////////////// + +/* + * Fixture for testing one var with distribution. + */ +struct add_node_fixture : ::testing::Test +{ +protected: + Variable x {3.0}; + Variable y {4.0}; +}; + +TEST_F(add_node_fixture, add_node_test) { + AddNode, Variable, double> addnode = x + y; + EXPECT_EQ(static_cast(addnode), 7.0); + + MultNode, Variable, double> multnode = x * y; + EXPECT_EQ(static_cast(multnode), 12.0); +} + +} // ppl \ No newline at end of file