From 44e088c3c589163bfeabf63900edb83c98901250 Mon Sep 17 00:00:00 2001 From: Jan-Grimo Sobez Date: Thu, 19 Aug 2021 11:13:36 +0200 Subject: [PATCH] Remove grad and gradx members from variable exprs --- autodiff/reverse/var/eigen.hpp | 45 ++++++++++----- autodiff/reverse/var/var.hpp | 100 ++++++++++++--------------------- 2 files changed, 67 insertions(+), 78 deletions(-) diff --git a/autodiff/reverse/var/eigen.hpp b/autodiff/reverse/var/eigen.hpp index a9fe168f..223162c9 100644 --- a/autodiff/reverse/var/eigen.hpp +++ b/autodiff/reverse/var/eigen.hpp @@ -103,54 +103,71 @@ auto gradient(const Variable& y, Eigen::DenseBase& x) constexpr auto MaxRows = X::MaxRowsAtCompileTime; const auto n = x.size(); + using Gradient = Vec; + Gradient g = Gradient::Zero(n); + for(auto i = 0; i < n; ++i) - x[i].seed(); + x[i].expr->bind_value(&g[i]); y.expr->propagate(1.0); - Vec g(n); for(auto i = 0; i < n; ++i) - g[i] = val(x[i].grad()); + x[i].expr->bind_value(nullptr); return g; } /// Return the Hessian matrix of variable y with respect to variables x. -template -auto hessian(const Variable& y, Eigen::DenseBase& x, Vec& g) +template +auto hessian(const Variable& y, Eigen::DenseBase& x, GradientVec& g) { using U = VariableValueType; using ScalarX = typename X::Scalar; static_assert(isVariable, "Argument x is not a vector with Variable (aka var) objects."); - using ScalarG = typename Vec::Scalar; + using ScalarG = typename GradientVec::Scalar; static_assert(std::is_same_v, "Argument g does not have the same arithmetic type as y."); constexpr auto Rows = X::RowsAtCompileTime; constexpr auto MaxRows = X::MaxRowsAtCompileTime; const auto n = x.size(); + + // Form a vector containing gradient expressions for each variable + using ExpressionGradient = Vec; + ExpressionGradient G(n); + for(auto k = 0; k < n; ++k) - x[k].seedx(); + x[k].expr->bind_expr(&G(k).expr); + /* Build a full gradient expression in DFS tree traversal, updating + * gradient expressions when encountering variables + */ y.expr->propagatex(constant(1.0)); + for(auto k = 0; k < n; ++k) { + x[k].expr->bind_expr(nullptr); + } + + // Read the gradient value from gradient expressions' cached values g.resize(n); for(auto i = 0; i < n; ++i) - g[i] = val(x[i].gradx()); + g[i] = val(G[i]); - Mat H(n, n); + // Form a numeric hessian using the gradient expressions + using Hessian = Mat; + Hessian H = Hessian::Zero(n, n); for(auto i = 0; i < n; ++i) { for(auto k = 0; k < n; ++k) - x[k].seed(); + x[k].expr->bind_value(&H(i, k)); - auto dydxi = x[i].gradx(); - dydxi->propagate(1.0); + // Propagate a second derivative value calculation down the gradient expression tree for variable i + G[i].expr->propagate(1.0); - for(auto j = i; j < n; ++j) - H(i, j) = H(j, i) = val(x[j].grad()); + for(auto k = 0; k < n; ++k) + x[k].expr->bind_value(nullptr); } return H; diff --git a/autodiff/reverse/var/var.hpp b/autodiff/reverse/var/var.hpp index 384da06e..eb74ecad 100644 --- a/autodiff/reverse/var/var.hpp +++ b/autodiff/reverse/var/var.hpp @@ -35,6 +35,7 @@ #include #include #include +#include // autodiff includes #include @@ -254,6 +255,9 @@ struct Expr /// Destructor (to avoid warning) virtual ~Expr() {} + virtual void bind_value(T* /* grad */) {} + virtual void bind_expr(ExprPtr* /* gradx */) {} + /// Update the contribution of this expression in the derivative of the root node of the expression tree. /// @param wprime The derivative of the root expression node w.r.t. the child expression of this expression node. virtual void propagate(const T& wprime) = 0; @@ -268,37 +272,35 @@ template struct VariableExpr : Expr { /// The derivative of the root expression node with respect to this variable. - T grad = {}; + T* gradPtr = {}; /// The derivative of the root expression node with respect to this variable (as an expression for higher-order derivatives). - ExprPtr gradx = {}; + ExprPtr* gradxPtr = {}; /// Construct a VariableExpr object with given value. VariableExpr(const T& v) : Expr(v) {} + + virtual void bind_value(T* grad) { gradPtr = grad; } + virtual void bind_expr(ExprPtr* gradx) { gradxPtr = gradx; } }; /// The node in the expression tree representing an independent variable. template struct IndependentVariableExpr : VariableExpr { - // Using declarations for data members of base class - using VariableExpr::grad; - using VariableExpr::gradx; + using VariableExpr::gradPtr; + using VariableExpr::gradxPtr; /// Construct an IndependentVariableExpr object with given value. - IndependentVariableExpr(const T& v) : VariableExpr(v) - { - gradx = constant(0.0); // TODO: Check if this can be done at the seed function. - } + IndependentVariableExpr(const T& v) : VariableExpr(v) {} - virtual void propagate(const T& wprime) - { - grad += wprime; + virtual void propagate(const T& wprime) { + if(gradPtr) { *gradPtr += wprime; } } virtual void propagatex(const ExprPtr& wprime) { - gradx = gradx + wprime; + if(gradxPtr) { *gradxPtr = *gradxPtr + wprime; } } }; @@ -306,28 +308,24 @@ struct IndependentVariableExpr : VariableExpr template struct DependentVariableExpr : VariableExpr { - // Using declarations for data members of base class - using VariableExpr::grad; - using VariableExpr::gradx; + using VariableExpr::gradPtr; + using VariableExpr::gradxPtr; /// The expression tree that defines how the dependent variable is calculated. ExprPtr expr; /// Construct an DependentVariableExpr object with given value. - DependentVariableExpr(const ExprPtr& e) : VariableExpr(e->val), expr(e) - { - gradx = constant(0.0); // TODO: Check if this can be done at the seed function. - } + DependentVariableExpr(const ExprPtr& e) : VariableExpr(e->val), expr(e) {} virtual void propagate(const T& wprime) { - grad += wprime; + if(gradPtr) { *gradPtr += wprime; } expr->propagate(wprime); } virtual void propagatex(const ExprPtr& wprime) { - gradx = gradx + wprime; + if(gradxPtr) { *gradxPtr = *gradxPtr + wprime; } expr->propagatex(wprime); } }; @@ -1065,21 +1063,6 @@ struct Variable /// Default copy assignment Variable &operator=(const Variable &) = default; - /// Return a pointer to the underlying VariableExpr object in this variable. - auto __variableExpr() const { return static_cast*>(expr.get()); } - - /// Return the derivative value stored in this variable. - auto grad() const { return __variableExpr()->grad; } - - /// Return the derivative expression stored in this variable. - auto gradx() const { return __variableExpr()->gradx; } - - /// Reeet the derivative value stored in this variable to zero. - auto seed() { __variableExpr()->grad = 0; } - - /// Reeet the derivative expression stored in this variable to zero expression. - auto seedx() { __variableExpr()->gradx = constant(0); } - /// Implicitly convert this Variable object into an expression pointer. operator ExprPtr() const { return expr; } @@ -1288,37 +1271,22 @@ auto wrt(Args&&... args) return Wrt{ std::forward_as_tuple(std::forward(args)...) }; } -/// Seed each variable in the **wrt** list. -template -auto seed(const Wrt& wrt) +/// Return the derivatives of a dependent variable y with respect given independent variables. +template +auto derivatives(const Variable& y, const Wrt& wrt) { - constexpr static auto N = sizeof...(Vars); - For([&](auto i) constexpr { - std::get(wrt.args).seed(); - }); -} + constexpr auto N = sizeof...(Vars); + std::array values; + values.fill(0.0); -/// Seed each variable in the **wrt** list. -template -auto seedx(const Wrt& wrt) -{ - constexpr static auto N = sizeof...(Vars); For([&](auto i) constexpr { - std::get(wrt.args).seedx(); + std::get(wrt.args).expr->bind_value(&values.at(i)); }); -} -/// Return the derivatives of a dependent variable y with respect given independent variables. -template -auto derivatives(const Variable& y, const Wrt& wrt) -{ - seed(wrt); y.expr->propagate(1.0); - constexpr static auto N = sizeof...(Vars); - std::array values; For([&](auto i) constexpr { - values[i.index] = std::get(wrt.args).grad(); + std::get(wrt.args).expr->bind_value(nullptr); }); return values; @@ -1328,13 +1296,17 @@ auto derivatives(const Variable& y, const Wrt& wrt) template auto derivativesx(const Variable& y, const Wrt& wrt) { - seedx(wrt); + constexpr auto N = sizeof...(Vars); + std::array, N> values; + + For([&](auto i) constexpr { + std::get(wrt.args).expr->bind_expr(&values.at(i).expr); + }); + y.expr->propagatex(constant(1.0)); - constexpr static auto N = sizeof...(Vars); - std::array, N> values; For([&](auto i) constexpr { - values[i.index] = std::get(wrt.args).gradx(); + std::get(wrt.args).expr->bind_expr(nullptr); }); return values;