Skip to content

Commit

Permalink
Merge 9d62612 into ca008a4
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobaustin123 committed Apr 18, 2020
2 parents ca008a4 + 9d62612 commit 42c5170
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
96 changes: 96 additions & 0 deletions include/autoppl/expression/operator.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#pragma once


#include <autoppl/expression/variable.hpp>

namespace ppl {

/*
* Node resulting from adding two Variables together, or a variable
* and an element of type Variable::value_t. Can be casted to value_t
* and returns the sum of the left and right subtrees.
*/
template <typename LeftChildType, typename RightChildType, typename ValueType>
struct AddNode
{
using value_t = ValueType;
AddNode(LeftChildType left, RightChildType right) : left_(left), right_(right) {
static_assert(std::is_convertible_v<LeftChildType, value_t>);
static_assert(std::is_convertible_v<RightChildType, value_t>);
}

operator value_t () const { return left_value() + right_value(); }

value_t left_value() const { return static_cast<value_t>(left_); }
value_t right_value() const { return static_cast<value_t>(right_); }

private:
LeftChildType left_;
RightChildType right_;
};

//////////////////////////////////
// Operator overloads for AddNode
///////////////////////////////////

template <typename ValueType>
AddNode<ValueType, Variable<ValueType>, ValueType> operator+(ValueType left, Variable<ValueType> right) {
return AddNode<ValueType, Variable<ValueType>, ValueType>(left, right);
}

template <typename ValueType>
AddNode<Variable<ValueType>, ValueType, ValueType> operator+(Variable<ValueType> left, ValueType right) {
return AddNode<Variable<ValueType>, ValueType, ValueType>(left, right);
}

template <typename ValueType>
AddNode<Variable<ValueType>, Variable<ValueType>, ValueType> operator+(Variable<ValueType> left, Variable<ValueType> right) {
return AddNode<Variable<ValueType>, Variable<ValueType>, ValueType>(left, right);
}


/*
* Node resulting from multiplying two Variables together, or a variable
* and an element of type Variable::value_t. Can be casted to value_t
* and returns the product of the left and right subtrees.
*/
template <typename LeftChildType, typename RightChildType, typename ValueType>
struct MultNode
{
using value_t = ValueType;
MultNode(LeftChildType left, RightChildType right) : left_(left), right_(right) {
static_assert(std::is_convertible_v<LeftChildType, value_t>);
static_assert(std::is_convertible_v<RightChildType, value_t>);
}

operator value_t () const { return left_value() * right_value(); }

value_t left_value() const { return static_cast<value_t>(left_); }
value_t right_value() const { return static_cast<value_t>(right_); }

private:
LeftChildType left_;
RightChildType right_;
};

//////////////////////////////////
// Operator overloads for MultNode
///////////////////////////////////

template <typename ValueType>
MultNode<ValueType, Variable<ValueType>, ValueType> operator*(ValueType left, Variable<ValueType> right) {
return MultNode<ValueType, Variable<ValueType>, ValueType>(left, right);
}

template <typename ValueType>
MultNode<Variable<ValueType>, ValueType, ValueType> operator*(Variable<ValueType> left, ValueType right) {
return MultNode<Variable<ValueType>, ValueType, ValueType>(left, right);
}

template <typename ValueType>
MultNode<Variable<ValueType>, Variable<ValueType>, ValueType> operator*(Variable<ValueType> left, Variable<ValueType> right) {
return MultNode<Variable<ValueType>, Variable<ValueType>, ValueType>(left, right);
}


} // ppl
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ endif()

add_executable(expression_unittest
${CMAKE_CURRENT_SOURCE_DIR}/expression/model_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/expression/operator_unittest.cpp
)
target_compile_options(expression_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(expression_unittest PRIVATE ${GTEST_DIR}/include)
Expand Down
31 changes: 31 additions & 0 deletions test/expression/operator_unittest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include <cmath>
#include <array>
#include "gtest/gtest.h"
#include <autoppl/expression/variable.hpp>
#include <autoppl/expression/operator.hpp>

namespace ppl {

//////////////////////////////////////////////////////
// Model with one RV TESTS
//////////////////////////////////////////////////////

/*
* Fixture for testing one var with distribution.
*/
struct add_node_fixture : ::testing::Test
{
protected:
Variable<double> x {3.0};
Variable<double> y {4.0};
};

TEST_F(add_node_fixture, add_node_test) {
AddNode<Variable<double>, Variable<double>, double> addnode = x + y;
EXPECT_EQ(static_cast<double>(addnode), 7.0);

MultNode<Variable<double>, Variable<double>, double> multnode = x * y;
EXPECT_EQ(static_cast<double>(multnode), 12.0);
}

} // ppl

0 comments on commit 42c5170

Please sign in to comment.