Remove grad and gradx members from variable exprs #177

Merged · 12 commits · Aug 24, 2021
45 changes: 31 additions & 14 deletions autodiff/reverse/var/eigen.hpp
@@ -103,54 +103,71 @@ auto gradient(const Variable<T>& y, Eigen::DenseBase<X>& x)
     constexpr auto MaxRows = X::MaxRowsAtCompileTime;

     const auto n = x.size();
+    using Gradient = Vec<U, Rows, MaxRows>;
+    Gradient g = Gradient::Zero(n);

     for(auto i = 0; i < n; ++i)
-        x[i].seed();
+        x[i].expr->bind_value(&g[i]);

     y.expr->propagate(1.0);

-    Vec<U, Rows, MaxRows> g(n);
     for(auto i = 0; i < n; ++i)
-        g[i] = val(x[i].grad());
+        x[i].expr->bind_value(nullptr);

     return g;
 }

 /// Return the Hessian matrix of variable y with respect to variables x.
-template<typename T, typename X, typename Vec>
-auto hessian(const Variable<T>& y, Eigen::DenseBase<X>& x, Vec& g)
+template<typename T, typename X, typename GradientVec>
+auto hessian(const Variable<T>& y, Eigen::DenseBase<X>& x, GradientVec& g)
 {
     using U = VariableValueType<T>;

     using ScalarX = typename X::Scalar;
     static_assert(isVariable<ScalarX>, "Argument x is not a vector with Variable<T> (aka var) objects.");

-    using ScalarG = typename Vec::Scalar;
+    using ScalarG = typename GradientVec::Scalar;
     static_assert(std::is_same_v<U, ScalarG>, "Argument g does not have the same arithmetic type as y.");

     constexpr auto Rows = X::RowsAtCompileTime;
     constexpr auto MaxRows = X::MaxRowsAtCompileTime;

     const auto n = x.size();

+    // Form a vector containing gradient expressions for each variable
+    using ExpressionGradient = Vec<ScalarX, Rows, MaxRows>;
+    ExpressionGradient G(n);
+
     for(auto k = 0; k < n; ++k)
-        x[k].seedx();
+        x[k].expr->bind_expr(&G(k).expr);

+    /* Build a full gradient expression in DFS tree traversal, updating
+     * gradient expressions when encountering variables
+     */
     y.expr->propagatex(constant<T>(1.0));

+    for(auto k = 0; k < n; ++k) {
+        x[k].expr->bind_expr(nullptr);
+    }
+
+    // Read the gradient value from gradient expressions' cached values
     g.resize(n);
     for(auto i = 0; i < n; ++i)
-        g[i] = val(x[i].gradx());
+        g[i] = val(G[i]);

-    Mat<U, Rows, Rows, MaxRows, MaxRows> H(n, n);
+    // Form a numeric Hessian using the gradient expressions
+    using Hessian = Mat<U, Rows, Rows, MaxRows, MaxRows>;
+    Hessian H = Hessian::Zero(n, n);
     for(auto i = 0; i < n; ++i)
     {
         for(auto k = 0; k < n; ++k)
-            x[k].seed();
+            x[k].expr->bind_value(&H(i, k));

-        auto dydxi = x[i].gradx();
-        dydxi->propagate(1.0);
+        // Propagate a second derivative value calculation down the gradient expression tree for variable i
+        G[i].expr->propagate(1.0);

-        for(auto j = i; j < n; ++j)
-            H(i, j) = H(j, i) = val(x[j].grad());
+        for(auto k = 0; k < n; ++k)
+            x[k].expr->bind_value(nullptr);
     }

     return H;
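Note: the public signatures of gradient() and hessian() are untouched by this change, only their internals. A minimal usage sketch, assuming the reverse-mode headers and the var/VectorXvar aliases provided by the library at the time of this PR (function and variable names below are illustrative, not from the PR):

```cpp
// Sketch: calling the gradient()/hessian() overloads changed above (illustrative only).
#include <autodiff/reverse/var.hpp>
#include <autodiff/reverse/var/eigen.hpp>
using namespace autodiff;

int main()
{
    VectorXvar x(3);                        // vector of var unknowns
    x << 1.0, 2.0, 3.0;

    var y = x[0] * x[0] + x[1] * x[2];      // some scalar function of x

    Eigen::VectorXd g = gradient(y, x);     // first derivatives; binds each x[i] via bind_value

    Eigen::VectorXd gh;                     // hessian() also returns the gradient through gh
    Eigen::MatrixXd H = hessian(y, x, gh);  // second derivatives; builds gradient expressions via bind_expr
}
```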
100 changes: 36 additions & 64 deletions autodiff/reverse/var/var.hpp
@@ -35,6 +35,7 @@
 #include <cmath>
 #include <cstddef>
 #include <memory>
+#include <unordered_map>

 // autodiff includes
 #include <autodiff/common/meta.hpp>
@@ -254,6 +255,9 @@ struct Expr
     /// Destructor (to avoid warning)
     virtual ~Expr() {}

+    virtual void bind_value(T* /* grad */) {}
Review comment (project member): Please add a short (Doxygen ///) comment here for these bind_* functions. I like your approach.
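In the spirit of that request, the bind_* hooks could be documented along these lines (suggested wording only, not part of the PR):

```cpp
/// Bind a memory location where the derivative value of the root node w.r.t. this
/// expression should be accumulated during propagate(); pass nullptr to unbind it.
virtual void bind_value(T* /* grad */) {}

/// Bind a memory location where the derivative expression of the root node w.r.t. this
/// expression should be accumulated during propagatex(); pass nullptr to unbind it.
virtual void bind_expr(ExprPtr<T>* /* gradx */) {}
```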

+    virtual void bind_expr(ExprPtr<T>* /* gradx */) {}
+
     /// Update the contribution of this expression in the derivative of the root node of the expression tree.
     /// @param wprime The derivative of the root expression node w.r.t. the child expression of this expression node.
     virtual void propagate(const T& wprime) = 0;
@@ -268,66 +272,60 @@ template<typename T>
 struct VariableExpr : Expr<T>
 {
     /// The derivative of the root expression node with respect to this variable.
-    T grad = {};
+    T* gradPtr = {};

     /// The derivative of the root expression node with respect to this variable (as an expression for higher-order derivatives).
-    ExprPtr<T> gradx = {};
+    ExprPtr<T>* gradxPtr = {};

     /// Construct a VariableExpr object with given value.
     VariableExpr(const T& v) : Expr<T>(v) {}

+    virtual void bind_value(T* grad) { gradPtr = grad; }
+    virtual void bind_expr(ExprPtr<T>* gradx) { gradxPtr = gradx; }
 };

 /// The node in the expression tree representing an independent variable.
 template<typename T>
 struct IndependentVariableExpr : VariableExpr<T>
 {
     // Using declarations for data members of base class
-    using VariableExpr<T>::grad;
-    using VariableExpr<T>::gradx;
+    using VariableExpr<T>::gradPtr;
+    using VariableExpr<T>::gradxPtr;

     /// Construct an IndependentVariableExpr object with given value.
-    IndependentVariableExpr(const T& v) : VariableExpr<T>(v)
-    {
-        gradx = constant<T>(0.0); // TODO: Check if this can be done at the seed function.
-    }
+    IndependentVariableExpr(const T& v) : VariableExpr<T>(v) {}

-    virtual void propagate(const T& wprime)
-    {
-        grad += wprime;
+    virtual void propagate(const T& wprime) {
+        if(gradPtr) { *gradPtr += wprime; }
     }

     virtual void propagatex(const ExprPtr<T>& wprime)
     {
-        gradx = gradx + wprime;
+        if(gradxPtr) { *gradxPtr = *gradxPtr + wprime; }
     }
 };

 /// The node in the expression tree representing a dependent variable.
 template<typename T>
 struct DependentVariableExpr : VariableExpr<T>
 {
     // Using declarations for data members of base class
-    using VariableExpr<T>::grad;
-    using VariableExpr<T>::gradx;
+    using VariableExpr<T>::gradPtr;
+    using VariableExpr<T>::gradxPtr;

     /// The expression tree that defines how the dependent variable is calculated.
     ExprPtr<T> expr;

     /// Construct a DependentVariableExpr object with given value.
-    DependentVariableExpr(const ExprPtr<T>& e) : VariableExpr<T>(e->val), expr(e)
-    {
-        gradx = constant<T>(0.0); // TODO: Check if this can be done at the seed function.
-    }
+    DependentVariableExpr(const ExprPtr<T>& e) : VariableExpr<T>(e->val), expr(e) {}

     virtual void propagate(const T& wprime)
     {
-        grad += wprime;
+        if(gradPtr) { *gradPtr += wprime; }
         expr->propagate(wprime);
     }

     virtual void propagatex(const ExprPtr<T>& wprime)
     {
-        gradx = gradx + wprime;
+        if(gradxPtr) { *gradxPtr = *gradxPtr + wprime; }
         expr->propagatex(wprime);
     }
 };
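The change above follows one pattern throughout: the caller temporarily binds an external storage slot, runs propagation, then unbinds, so the expression nodes no longer carry per-call derivative state. A stripped-down, hypothetical illustration of that pattern (not autodiff code; names and types below are invented for the sketch):

```cpp
// Hypothetical, self-contained illustration of the bind/propagate pattern used above.
#include <cassert>

struct Node
{
    double* gradPtr = nullptr;            // storage owned by the caller, not the node

    void bind_value(double* grad) { gradPtr = grad; }

    void propagate(double wprime)
    {
        if(gradPtr) *gradPtr += wprime;   // accumulate only while a slot is bound
    }
};

int main()
{
    Node x;
    double dydx = 0.0;

    x.bind_value(&dydx);                  // bind caller-owned storage
    x.propagate(3.0);                     // derivative contributions accumulate into dydx
    x.propagate(2.0);
    x.bind_value(nullptr);                // unbind: the node keeps no derivative state

    assert(dydx == 5.0);
}
```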
@@ -1057,21 +1055,6 @@ struct Variable
     /// Default copy assignment
     Variable &operator=(const Variable &) = default;

-    /// Return a pointer to the underlying VariableExpr object in this variable.
-    auto __variableExpr() const { return static_cast<VariableExpr<T>*>(expr.get()); }
-
-    /// Return the derivative value stored in this variable.
-    auto grad() const { return __variableExpr()->grad; }
-
-    /// Return the derivative expression stored in this variable.
-    auto gradx() const { return __variableExpr()->gradx; }
-
-    /// Reset the derivative value stored in this variable to zero.
-    auto seed() { __variableExpr()->grad = 0; }
-
-    /// Reset the derivative expression stored in this variable to the zero expression.
-    auto seedx() { __variableExpr()->gradx = constant<T>(0); }
-
     /// Implicitly convert this Variable object into an expression pointer.
     operator ExprPtr<T>() const { return expr; }

@@ -1280,37 +1263,22 @@ auto wrt(Args&&... args)
     return Wrt<Args&&...>{ std::forward_as_tuple(std::forward<Args>(args)...) };
 }

-/// Seed each variable in the **wrt** list.
-template<typename... Vars>
-auto seed(const Wrt<Vars...>& wrt)
+/// Return the derivatives of a dependent variable y with respect to given independent variables.
+template<typename T, typename... Vars>
+auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
 {
-    constexpr static auto N = sizeof...(Vars);
-    For<N>([&](auto i) constexpr {
-        std::get<i>(wrt.args).seed();
-    });
-}
+    constexpr auto N = sizeof...(Vars);
+    std::array<T, N> values;
+    values.fill(0.0);

-/// Seed each variable in the **wrt** list.
-template<typename... Vars>
-auto seedx(const Wrt<Vars...>& wrt)
-{
-    constexpr static auto N = sizeof...(Vars);
     For<N>([&](auto i) constexpr {
-        std::get<i>(wrt.args).seedx();
+        std::get<i>(wrt.args).expr->bind_value(&values.at(i));
     });
-}

-/// Return the derivatives of a dependent variable y with respect to given independent variables.
-template<typename T, typename... Vars>
-auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
-{
-    seed(wrt);
     y.expr->propagate(1.0);

-    constexpr static auto N = sizeof...(Vars);
-    std::array<T, N> values;
     For<N>([&](auto i) constexpr {
-        values[i.index] = std::get<i>(wrt.args).grad();
+        std::get<i>(wrt.args).expr->bind_value(nullptr);
     });

     return values;
@@ -1320,13 +1288,17 @@ auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
 template<typename T, typename... Vars>
 auto derivativesx(const Variable<T>& y, const Wrt<Vars...>& wrt)
 {
-    seedx(wrt);
+    constexpr auto N = sizeof...(Vars);
+    std::array<Variable<T>, N> values;
+
+    For<N>([&](auto i) constexpr {
+        std::get<i>(wrt.args).expr->bind_expr(&values.at(i).expr);
+    });
+
     y.expr->propagatex(constant<T>(1.0));

-    constexpr static auto N = sizeof...(Vars);
-    std::array<Variable<T>, N> values;
     For<N>([&](auto i) constexpr {
-        values[i.index] = std::get<i>(wrt.args).gradx();
+        std::get<i>(wrt.args).expr->bind_expr(nullptr);
     });

     return values;
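Finally, the user-facing entry points keep their existing call pattern. A minimal sketch, assuming the reverse-mode var type and the derivatives/derivativesx overloads shown above (variable names are illustrative):

```cpp
// Sketch: first-order and expression-valued derivatives through the rewritten API.
#include <autodiff/reverse/var.hpp>
using namespace autodiff;

int main()
{
    var x = 1.0;
    var y = 2.0;
    var u = x * x + x * y;                      // u(x, y) = x^2 + x*y

    auto [ux, uy] = derivatives(u, wrt(x, y));  // plain values: ux = 2x + y = 4, uy = x = 1

    auto [ux_expr] = derivativesx(u, wrt(x));   // du/dx kept as a var, so it can be
                                                // differentiated again for higher orders
}
```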