Skip to content

Commit

Permalink
Fix aliasing issue (#32) (#33)
Browse files Browse the repository at this point in the history
* Fix autodiff header figure for changes in v0.5.0

* Fix issue 32 - preventing aliasing

* Increment version to 0.5.1
  • Loading branch information
allanleal committed Jun 18, 2019
1 parent 636dc8f commit 31b0020
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 24 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(CCache)

# Name and details of the project
project(autodiff VERSION 0.5.0 LANGUAGES CXX)
project(autodiff VERSION 0.5.1 LANGUAGES CXX)

# Include the cmake variables with values for installation directories
include(GNUInstallDirs)
Expand Down
20 changes: 15 additions & 5 deletions autodiff/forward/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,35 +470,45 @@ struct Dual
template<typename U, enableif<isNumber<U> || isExpr<U>>...>
Dual& operator=(U&& other)
{
assign(*this, std::forward<U>(other));
Dual tmp;
assign(tmp, std::forward<U>(other));
assign(*this, tmp);
return *this;
}

template<typename U, enableif<isNumber<U> || isExpr<U>>...>
Dual& operator+=(U&& other)
{
assignAdd(*this, std::forward<U>(other));
Dual tmp;
assign(tmp, std::forward<U>(other));
assignAdd(*this, tmp);
return *this;
}

template<typename U, enableif<isNumber<U> || isExpr<U>>...>
Dual& operator-=(U&& other)
{
assignSub(*this, std::forward<U>(other));
Dual tmp;
assign(tmp, std::forward<U>(other));
assignSub(*this, tmp);
return *this;
}

template<typename U, enableif<isNumber<U> || isExpr<U>>...>
Dual& operator*=(U&& other)
{
assignMul(*this, std::forward<U>(other));
Dual tmp;
assign(tmp, std::forward<U>(other));
assignMul(*this, tmp);
return *this;
}

template<typename U, enableif<isNumber<U> || isExpr<U>>...>
Dual& operator/=(U&& other)
{
assignDiv(*this, std::forward<U>(other));
Dual tmp;
assign(tmp, std::forward<U>(other));
assignDiv(*this, tmp);
return *this;
}
};
Expand Down
54 changes: 36 additions & 18 deletions test/forward.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ TEST_CASE("autodiff::dual tests", "[dual]")
REQUIRE( x == 10 );
}

SECTION("aliasing tests")
{
x = 1; x = x + 3*x - 2*x + x;
REQUIRE( x == 3 );

x = 1; x += x + 3*x - 2*x + x;
REQUIRE( x == 4 );

x = 1; x -= x + 3*x - 2*x + x;
REQUIRE( x == -2 );

x = 1; x *= x + 3*x - 2*x + x;
REQUIRE( x == 3 );

x = 1; x /= x + x;
REQUIRE( x == 0.5 );
}

SECTION("testing comparison operators")
{
x = 6;
Expand Down Expand Up @@ -394,38 +412,38 @@ TEST_CASE("autodiff::dual tests", "[dual]")

f = [](dual x, dual y) -> dual { return x *= 2; };
REQUIRE( f(x, y) == approx(2.0 * x) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 2.0 );
REQUIRE( derivative(f, wrt(y), at(x, y)) == 0.0 );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(2.0) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(0.0) );

f = [](dual x, dual y) -> dual { return x *= y; };
REQUIRE( f(x, y) == approx(x * y) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == y );
REQUIRE( derivative(f, wrt(y), at(x, y)) == x );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(y) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(x) );

f = [](dual x, dual y) -> dual { return x *= -x; };
REQUIRE( f(x, y) == approx(-x * x) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(-2.0 * x) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == 0.0 );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(0.0) );

f = [](dual x, dual y) -> dual { return x *= (2.0 / y); };
REQUIRE( f(x, y) == approx(2.0 * x / y) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 2.0 / y );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(2.0 / y) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(-2.0 * x / (y * y)) );

f = [](dual x, dual y) -> dual { return x *= (2.0 * x); };
REQUIRE( f(x, y) == approx(2.0 * x * x) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(4.0 * x) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == 0.0 );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(0.0) );

f = [](dual x, dual y) -> dual { return x *= (2.0 * y); };
REQUIRE( f(x, y) == approx(2.0 * x * y) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 2.0 * y );
REQUIRE( derivative(f, wrt(y), at(x, y)) == 2.0 * x );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(2.0 * y) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(2.0 * x) );

f = [](dual x, dual y) -> dual { return x *= x + y; };
REQUIRE( f(x, y) == approx(x * (x + y)) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 2.0 * x + y );
REQUIRE( derivative(f, wrt(y), at(x, y)) == x );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(2.0 * x + y) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(x) );

f = [](dual x, dual y) -> dual { return x *= x * y; };
REQUIRE( f(x, y) == approx(x * (x * y)) );
Expand All @@ -439,8 +457,8 @@ TEST_CASE("autodiff::dual tests", "[dual]")

f = [](dual x, dual y) -> dual { return x /= 2; };
REQUIRE( f(x, y) == approx(0.5 * x) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 0.5 );
REQUIRE( derivative(f, wrt(y), at(x, y)) == 0.0 );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(0.5) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(0.0) );

f = [](dual x, dual y) -> dual { return x /= y; };
REQUIRE( f(x, y) == approx(x / y) );
Expand All @@ -449,13 +467,13 @@ TEST_CASE("autodiff::dual tests", "[dual]")

f = [](dual x, dual y) -> dual { return x /= -x; };
REQUIRE( f(x, y) == approx(-1.0) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 0.0 );
REQUIRE( derivative(f, wrt(y), at(x, y)) == 0.0 );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(0.0) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(0.0) );

f = [](dual x, dual y) -> dual { return x /= (2.0 / y); };
REQUIRE( f(x, y) == approx(0.5 * x * y) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 0.5 * y );
REQUIRE( derivative(f, wrt(y), at(x, y)) == 0.5 * x );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(0.5 * y) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(0.5 * x) );

f = [](dual x, dual y) -> dual { return x /= (2.0 * y); };
REQUIRE( f(x, y) == approx(0.5 * x / y) );
Expand All @@ -469,7 +487,7 @@ TEST_CASE("autodiff::dual tests", "[dual]")

f = [](dual x, dual y) -> dual { return x /= x * y; };
REQUIRE( f(x, y) == approx(1.0 / y) );
REQUIRE( derivative(f, wrt(x), at(x, y)) == 0.0 );
REQUIRE( derivative(f, wrt(x), at(x, y)) == approx(0.0) );
REQUIRE( derivative(f, wrt(y), at(x, y)) == approx(-1.0 / (y * y)) );
}

Expand Down

0 comments on commit 31b0020

Please sign in to comment.