Skip to content

Commit

Permalink
Revert "Remove grad from variable exprs"
Browse files Browse the repository at this point in the history
This reverts parts of commit 44e088c.

The var.grad() api is more convenient than requiring to first bind a
float and then reading it back. for instance, expr.propagate() and
var.grad() can be used in unit tests. see also
autodiff#177 (comment)
  • Loading branch information
adam-ce committed Aug 25, 2021
1 parent 2309fe0 commit 2216853
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
14 changes: 6 additions & 8 deletions autodiff/reverse/var/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,14 @@ auto gradient(const Variable<T>& y, Eigen::DenseBase<X>& x)
constexpr auto MaxRows = X::MaxRowsAtCompileTime;

const auto n = x.size();
using Gradient = Vec<U, Rows, MaxRows>;
Gradient g = Gradient::Zero(n);

for(auto i = 0; i < n; ++i)
x[i].expr->bind_value(&g[i]);
x[i].seed();

y.expr->propagate(1.0);

Vec<U, Rows, MaxRows> g(n);
for(auto i = 0; i < n; ++i)
x[i].expr->bind_value(nullptr);
g[i] = val(x[i].grad());

return g;
}
Expand Down Expand Up @@ -161,13 +159,13 @@ auto hessian(const Variable<T>& y, Eigen::DenseBase<X>& x, GradientVec& g)
for(auto i = 0; i < n; ++i)
{
for(auto k = 0; k < n; ++k)
x[k].expr->bind_value(&H(i, k));
x[k].seed();

// Propagate a second derivative value calculation down the gradient expression tree for variable i
G[i].expr->propagate(1.0);

for(auto k = 0; k < n; ++k)
x[k].expr->bind_value(nullptr);
for(auto j = i; j < n; ++j)
H(i, j) = H(j, i) = val(x[j].grad());
}

return H;
Expand Down
48 changes: 30 additions & 18 deletions autodiff/reverse/var/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ struct Expr
virtual ~Expr() {}

/// Bind a value pointer for writing the derivative during propagation
virtual void bind_value(T* /* grad */) {}
/// Bind an expression pointer for writing the derivative expression during propagation
virtual void bind_expr(ExprPtr<T>* /* gradx */) {}

Expand All @@ -273,30 +272,29 @@ template<typename T>
struct VariableExpr : Expr<T>
{
/// The derivative value of the root expression node w.r.t. this variable.
T* gradPtr = {};
T grad = {};

/// The derivative expression of the root expression node w.r.t. this variable (reusable for higher-order derivatives).
ExprPtr<T>* gradxPtr = {};

/// Construct a VariableExpr object with given value.
VariableExpr(const T& v) : Expr<T>(v) {}

virtual void bind_value(T* grad) { gradPtr = grad; }
virtual void bind_expr(ExprPtr<T>* gradx) { gradxPtr = gradx; }
};

/// The node in the expression tree representing an independent variable.
template<typename T>
struct IndependentVariableExpr : VariableExpr<T>
{
using VariableExpr<T>::gradPtr;
using VariableExpr<T>::gradxPtr;
using VariableExpr<T>::grad;

/// Construct an IndependentVariableExpr object with given value.
IndependentVariableExpr(const T& v) : VariableExpr<T>(v) {}

virtual void propagate(const T& wprime) {
if(gradPtr) { *gradPtr += wprime; }
virtual void propagate(const T& wprime)
{
grad += wprime;
}

virtual void propagatex(const ExprPtr<T>& wprime)
Expand All @@ -309,8 +307,8 @@ struct IndependentVariableExpr : VariableExpr<T>
template<typename T>
struct DependentVariableExpr : VariableExpr<T>
{
using VariableExpr<T>::gradPtr;
using VariableExpr<T>::gradxPtr;
using VariableExpr<T>::grad;

/// The expression tree that defines how the dependent variable is calculated.
ExprPtr<T> expr;
Expand All @@ -320,7 +318,7 @@ struct DependentVariableExpr : VariableExpr<T>

virtual void propagate(const T& wprime)
{
if(gradPtr) { *gradPtr += wprime; }
grad += wprime;
expr->propagate(wprime);
}

Expand Down Expand Up @@ -1064,6 +1062,16 @@ struct Variable
/// Default copy assignment
Variable &operator=(const Variable &) = default;

/// Return a pointer to the underlying VariableExpr object in this variable.
auto __variableExpr() const { return static_cast<VariableExpr<T>*>(expr.get()); }

/// Return the derivative value stored in this variable.
auto grad() const { return __variableExpr()->grad; }


/// Reeet the derivative value stored in this variable to zero.
auto seed() { __variableExpr()->grad = 0; }

/// Implicitly convert this Variable object into an expression pointer.
operator ExprPtr<T>() const { return expr; }

Expand Down Expand Up @@ -1272,22 +1280,26 @@ auto wrt(Args&&... args)
return Wrt<Args&&...>{ std::forward_as_tuple(std::forward<Args>(args)...) };
}

/// Seed each variable in the **wrt** list.
template<typename... Vars>
auto seed(const Wrt<Vars...>& wrt)
{
constexpr static auto N = sizeof...(Vars);
For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).seed();
});
}
/// Return the derivatives of a dependent variable y with respect given independent variables.
template<typename T, typename... Vars>
auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
{
constexpr auto N = sizeof...(Vars);
std::array<T, N> values;
values.fill(0.0);

For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).expr->bind_value(&values.at(i));
});

seed(wrt);
y.expr->propagate(1.0);

constexpr static auto N = sizeof...(Vars);
std::array<T, N> values;
For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).expr->bind_value(nullptr);
values[i.index] = std::get<i>(wrt.args).grad();
});

return values;
Expand Down

0 comments on commit 2216853

Please sign in to comment.