Skip to content

Commit

Permalink
Remove grad and gradx members from variable exprs
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-grimo authored and allanleal committed Aug 24, 2021
1 parent 370ca6e commit 44e088c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 78 deletions.
45 changes: 31 additions & 14 deletions autodiff/reverse/var/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,54 +103,71 @@ auto gradient(const Variable<T>& y, Eigen::DenseBase<X>& x)
constexpr auto MaxRows = X::MaxRowsAtCompileTime;

const auto n = x.size();
using Gradient = Vec<U, Rows, MaxRows>;
Gradient g = Gradient::Zero(n);

for(auto i = 0; i < n; ++i)
x[i].seed();
x[i].expr->bind_value(&g[i]);

y.expr->propagate(1.0);

Vec<U, Rows, MaxRows> g(n);
for(auto i = 0; i < n; ++i)
g[i] = val(x[i].grad());
x[i].expr->bind_value(nullptr);

return g;
}

/// Return the Hessian matrix of variable y with respect to variables x.
template<typename T, typename X, typename Vec>
auto hessian(const Variable<T>& y, Eigen::DenseBase<X>& x, Vec& g)
template<typename T, typename X, typename GradientVec>
auto hessian(const Variable<T>& y, Eigen::DenseBase<X>& x, GradientVec& g)
{
using U = VariableValueType<T>;

using ScalarX = typename X::Scalar;
static_assert(isVariable<ScalarX>, "Argument x is not a vector with Variable<T> (aka var) objects.");

using ScalarG = typename Vec::Scalar;
using ScalarG = typename GradientVec::Scalar;
static_assert(std::is_same_v<U, ScalarG>, "Argument g does not have the same arithmetic type as y.");

constexpr auto Rows = X::RowsAtCompileTime;
constexpr auto MaxRows = X::MaxRowsAtCompileTime;

const auto n = x.size();

// Form a vector containing gradient expressions for each variable
using ExpressionGradient = Vec<ScalarX, Rows, MaxRows>;
ExpressionGradient G(n);

for(auto k = 0; k < n; ++k)
x[k].seedx();
x[k].expr->bind_expr(&G(k).expr);

/* Build a full gradient expression in DFS tree traversal, updating
* gradient expressions when encountering variables
*/
y.expr->propagatex(constant<T>(1.0));

for(auto k = 0; k < n; ++k) {
x[k].expr->bind_expr(nullptr);
}

// Read the gradient value from gradient expressions' cached values
g.resize(n);
for(auto i = 0; i < n; ++i)
g[i] = val(x[i].gradx());
g[i] = val(G[i]);

Mat<U, Rows, Rows, MaxRows, MaxRows> H(n, n);
// Form a numeric hessian using the gradient expressions
using Hessian = Mat<U, Rows, Rows, MaxRows, MaxRows>;
Hessian H = Hessian::Zero(n, n);
for(auto i = 0; i < n; ++i)
{
for(auto k = 0; k < n; ++k)
x[k].seed();
x[k].expr->bind_value(&H(i, k));

auto dydxi = x[i].gradx();
dydxi->propagate(1.0);
// Propagate a second derivative value calculation down the gradient expression tree for variable i
G[i].expr->propagate(1.0);

for(auto j = i; j < n; ++j)
H(i, j) = H(j, i) = val(x[j].grad());
for(auto k = 0; k < n; ++k)
x[k].expr->bind_value(nullptr);
}

return H;
Expand Down
100 changes: 36 additions & 64 deletions autodiff/reverse/var/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <cmath>
#include <cstddef>
#include <memory>
#include <unordered_map>

// autodiff includes
#include <autodiff/common/meta.hpp>
Expand Down Expand Up @@ -254,6 +255,9 @@ struct Expr
/// Destructor (to avoid warning)
virtual ~Expr() {}

virtual void bind_value(T* /* grad */) {}
virtual void bind_expr(ExprPtr<T>* /* gradx */) {}

/// Update the contribution of this expression in the derivative of the root node of the expression tree.
/// @param wprime The derivative of the root expression node w.r.t. the child expression of this expression node.
virtual void propagate(const T& wprime) = 0;
Expand All @@ -268,66 +272,60 @@ template<typename T>
struct VariableExpr : Expr<T>
{
/// The derivative of the root expression node with respect to this variable.
T grad = {};
T* gradPtr = {};

/// The derivative of the root expression node with respect to this variable (as an expression for higher-order derivatives).
ExprPtr<T> gradx = {};
ExprPtr<T>* gradxPtr = {};

/// Construct a VariableExpr object with given value.
VariableExpr(const T& v) : Expr<T>(v) {}

virtual void bind_value(T* grad) { gradPtr = grad; }
virtual void bind_expr(ExprPtr<T>* gradx) { gradxPtr = gradx; }
};

/// The node in the expression tree representing an independent variable.
template<typename T>
struct IndependentVariableExpr : VariableExpr<T>
{
// Using declarations for data members of base class
using VariableExpr<T>::grad;
using VariableExpr<T>::gradx;
using VariableExpr<T>::gradPtr;
using VariableExpr<T>::gradxPtr;

/// Construct an IndependentVariableExpr object with given value.
IndependentVariableExpr(const T& v) : VariableExpr<T>(v)
{
gradx = constant<T>(0.0); // TODO: Check if this can be done at the seed function.
}
IndependentVariableExpr(const T& v) : VariableExpr<T>(v) {}

virtual void propagate(const T& wprime)
{
grad += wprime;
virtual void propagate(const T& wprime) {
if(gradPtr) { *gradPtr += wprime; }
}

virtual void propagatex(const ExprPtr<T>& wprime)
{
gradx = gradx + wprime;
if(gradxPtr) { *gradxPtr = *gradxPtr + wprime; }
}
};

/// The node in the expression tree representing a dependent variable.
template<typename T>
struct DependentVariableExpr : VariableExpr<T>
{
// Using declarations for data members of base class
using VariableExpr<T>::grad;
using VariableExpr<T>::gradx;
using VariableExpr<T>::gradPtr;
using VariableExpr<T>::gradxPtr;

/// The expression tree that defines how the dependent variable is calculated.
ExprPtr<T> expr;

/// Construct an DependentVariableExpr object with given value.
DependentVariableExpr(const ExprPtr<T>& e) : VariableExpr<T>(e->val), expr(e)
{
gradx = constant<T>(0.0); // TODO: Check if this can be done at the seed function.
}
DependentVariableExpr(const ExprPtr<T>& e) : VariableExpr<T>(e->val), expr(e) {}

virtual void propagate(const T& wprime)
{
grad += wprime;
if(gradPtr) { *gradPtr += wprime; }
expr->propagate(wprime);
}

virtual void propagatex(const ExprPtr<T>& wprime)
{
gradx = gradx + wprime;
if(gradxPtr) { *gradxPtr = *gradxPtr + wprime; }
expr->propagatex(wprime);
}
};
Expand Down Expand Up @@ -1065,21 +1063,6 @@ struct Variable
/// Default copy assignment
Variable &operator=(const Variable &) = default;

/// Return a pointer to the underlying VariableExpr object in this variable.
auto __variableExpr() const { return static_cast<VariableExpr<T>*>(expr.get()); }

/// Return the derivative value stored in this variable.
auto grad() const { return __variableExpr()->grad; }

/// Return the derivative expression stored in this variable.
auto gradx() const { return __variableExpr()->gradx; }

/// Reeet the derivative value stored in this variable to zero.
auto seed() { __variableExpr()->grad = 0; }

/// Reeet the derivative expression stored in this variable to zero expression.
auto seedx() { __variableExpr()->gradx = constant<T>(0); }

/// Implicitly convert this Variable object into an expression pointer.
operator ExprPtr<T>() const { return expr; }

Expand Down Expand Up @@ -1288,37 +1271,22 @@ auto wrt(Args&&... args)
return Wrt<Args&&...>{ std::forward_as_tuple(std::forward<Args>(args)...) };
}

/// Seed each variable in the **wrt** list.
template<typename... Vars>
auto seed(const Wrt<Vars...>& wrt)
/// Return the derivatives of a dependent variable y with respect given independent variables.
template<typename T, typename... Vars>
auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
{
constexpr static auto N = sizeof...(Vars);
For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).seed();
});
}
constexpr auto N = sizeof...(Vars);
std::array<T, N> values;
values.fill(0.0);

/// Seed each variable in the **wrt** list.
template<typename... Vars>
auto seedx(const Wrt<Vars...>& wrt)
{
constexpr static auto N = sizeof...(Vars);
For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).seedx();
std::get<i>(wrt.args).expr->bind_value(&values.at(i));
});
}

/// Return the derivatives of a dependent variable y with respect given independent variables.
template<typename T, typename... Vars>
auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
{
seed(wrt);
y.expr->propagate(1.0);

constexpr static auto N = sizeof...(Vars);
std::array<T, N> values;
For<N>([&](auto i) constexpr {
values[i.index] = std::get<i>(wrt.args).grad();
std::get<i>(wrt.args).expr->bind_value(nullptr);
});

return values;
Expand All @@ -1328,13 +1296,17 @@ auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
template<typename T, typename... Vars>
auto derivativesx(const Variable<T>& y, const Wrt<Vars...>& wrt)
{
seedx(wrt);
constexpr auto N = sizeof...(Vars);
std::array<Variable<T>, N> values;

For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).expr->bind_expr(&values.at(i).expr);
});

y.expr->propagatex(constant<T>(1.0));

constexpr static auto N = sizeof...(Vars);
std::array<Variable<T>, N> values;
For<N>([&](auto i) constexpr {
values[i.index] = std::get<i>(wrt.args).gradx();
std::get<i>(wrt.args).expr->bind_expr(nullptr);
});

return values;
Expand Down

0 comments on commit 44e088c

Please sign in to comment.