-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#pragma once | ||
|
||
|
||
#include <autoppl/expression/variable.hpp> | ||
|
||
namespace ppl { | ||
|
||
/* | ||
* Node resulting from adding two Variables together, or a variable | ||
* and an element of type Variable::value_t. Can be casted to value_t | ||
* and returns the sum of the left and right subtrees. | ||
*/ | ||
template <typename LeftChildType, typename RightChildType, typename ValueType> | ||
struct AddNode | ||
{ | ||
using value_t = ValueType; | ||
AddNode(LeftChildType left, RightChildType right) : left_(left), right_(right) { | ||
static_assert(std::is_convertible_v<LeftChildType, value_t>); | ||
static_assert(std::is_convertible_v<RightChildType, value_t>); | ||
} | ||
|
||
operator value_t () const { return left_value() + right_value(); } | ||
|
||
value_t left_value() const { return static_cast<value_t>(left_); } | ||
value_t right_value() const { return static_cast<value_t>(right_); } | ||
|
||
private: | ||
LeftChildType left_; | ||
RightChildType right_; | ||
}; | ||
|
||
////////////////////////////////// | ||
// Operator overloads for AddNode | ||
/////////////////////////////////// | ||
|
||
template <typename ValueType> | ||
AddNode<ValueType, Variable<ValueType>, ValueType> operator+(ValueType left, Variable<ValueType> right) { | ||
return AddNode<ValueType, Variable<ValueType>, ValueType>(left, right); | ||
} | ||
|
||
template <typename ValueType> | ||
AddNode<Variable<ValueType>, ValueType, ValueType> operator+(Variable<ValueType> left, ValueType right) { | ||
return AddNode<Variable<ValueType>, ValueType, ValueType>(left, right); | ||
} | ||
|
||
template <typename ValueType> | ||
AddNode<Variable<ValueType>, Variable<ValueType>, ValueType> operator+(Variable<ValueType> left, Variable<ValueType> right) { | ||
return AddNode<Variable<ValueType>, Variable<ValueType>, ValueType>(left, right); | ||
} | ||
|
||
|
||
/* | ||
* Node resulting from multiplying two Variables together, or a variable | ||
* and an element of type Variable::value_t. Can be casted to value_t | ||
* and returns the product of the left and right subtrees. | ||
*/ | ||
template <typename LeftChildType, typename RightChildType, typename ValueType> | ||
struct MultNode | ||
{ | ||
using value_t = ValueType; | ||
MultNode(LeftChildType left, RightChildType right) : left_(left), right_(right) { | ||
static_assert(std::is_convertible_v<LeftChildType, value_t>); | ||
static_assert(std::is_convertible_v<RightChildType, value_t>); | ||
} | ||
|
||
operator value_t () const { return left_value() * right_value(); } | ||
|
||
value_t left_value() const { return static_cast<value_t>(left_); } | ||
value_t right_value() const { return static_cast<value_t>(right_); } | ||
|
||
private: | ||
LeftChildType left_; | ||
RightChildType right_; | ||
}; | ||
|
||
////////////////////////////////// | ||
// Operator overloads for MultNode | ||
/////////////////////////////////// | ||
|
||
template <typename ValueType> | ||
MultNode<ValueType, Variable<ValueType>, ValueType> operator*(ValueType left, Variable<ValueType> right) { | ||
return MultNode<ValueType, Variable<ValueType>, ValueType>(left, right); | ||
} | ||
|
||
template <typename ValueType> | ||
MultNode<Variable<ValueType>, ValueType, ValueType> operator*(Variable<ValueType> left, ValueType right) { | ||
return MultNode<Variable<ValueType>, ValueType, ValueType>(left, right); | ||
} | ||
|
||
template <typename ValueType> | ||
MultNode<Variable<ValueType>, Variable<ValueType>, ValueType> operator*(Variable<ValueType> left, Variable<ValueType> right) { | ||
return MultNode<Variable<ValueType>, Variable<ValueType>, ValueType>(left, right); | ||
} | ||
|
||
|
||
} // ppl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#include <cmath> | ||
#include <array> | ||
#include "gtest/gtest.h" | ||
#include <autoppl/expression/variable.hpp> | ||
#include <autoppl/expression/operator.hpp> | ||
|
||
namespace ppl { | ||
|
||
////////////////////////////////////////////////////// | ||
// Model with one RV TESTS | ||
////////////////////////////////////////////////////// | ||
|
||
/* | ||
* Fixture for testing one var with distribution. | ||
*/ | ||
struct add_node_fixture : ::testing::Test | ||
{ | ||
protected: | ||
Variable<double> x {3.0}; | ||
Variable<double> y {4.0}; | ||
}; | ||
|
||
TEST_F(add_node_fixture, add_node_test) { | ||
AddNode<Variable<double>, Variable<double>, double> addnode = x + y; | ||
EXPECT_EQ(static_cast<double>(addnode), 7.0); | ||
|
||
MultNode<Variable<double>, Variable<double>, double> multnode = x * y; | ||
EXPECT_EQ(static_cast<double>(multnode), 12.0); | ||
} | ||
|
||
} // ppl |