Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Remove grad from variable exprs" #181

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions autodiff/reverse/var/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,14 @@ auto gradient(const Variable<T>& y, Eigen::DenseBase<X>& x)
constexpr auto MaxRows = X::MaxRowsAtCompileTime;

const auto n = x.size();
using Gradient = Vec<U, Rows, MaxRows>;
Gradient g = Gradient::Zero(n);

for(auto i = 0; i < n; ++i)
x[i].expr->bind_value(&g[i]);
x[i].seed();

y.expr->propagate(1.0);

Vec<U, Rows, MaxRows> g(n);
for(auto i = 0; i < n; ++i)
x[i].expr->bind_value(nullptr);
g[i] = val(x[i].grad());

return g;
}
Expand Down Expand Up @@ -161,13 +159,13 @@ auto hessian(const Variable<T>& y, Eigen::DenseBase<X>& x, GradientVec& g)
for(auto i = 0; i < n; ++i)
{
for(auto k = 0; k < n; ++k)
x[k].expr->bind_value(&H(i, k));
x[k].seed();

// Propagate a second derivative value calculation down the gradient expression tree for variable i
G[i].expr->propagate(1.0);

for(auto k = 0; k < n; ++k)
x[k].expr->bind_value(nullptr);
for(auto j = i; j < n; ++j)
H(i, j) = H(j, i) = val(x[j].grad());
}

return H;
Expand Down
48 changes: 30 additions & 18 deletions autodiff/reverse/var/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ struct Expr
virtual ~Expr() {}

/// Bind a value pointer for writing the derivative during propagation
virtual void bind_value(T* /* grad */) {}
/// Bind an expression pointer for writing the derivative expression during propagation
virtual void bind_expr(ExprPtr<T>* /* gradx */) {}

Expand All @@ -273,30 +272,29 @@ template<typename T>
struct VariableExpr : Expr<T>
{
/// The derivative value of the root expression node w.r.t. this variable.
T* gradPtr = {};
T grad = {};

/// The derivative expression of the root expression node w.r.t. this variable (reusable for higher-order derivatives).
ExprPtr<T>* gradxPtr = {};

/// Construct a VariableExpr object with given value.
VariableExpr(const T& v) : Expr<T>(v) {}

virtual void bind_value(T* grad) { gradPtr = grad; }
virtual void bind_expr(ExprPtr<T>* gradx) { gradxPtr = gradx; }
};

/// The node in the expression tree representing an independent variable.
template<typename T>
struct IndependentVariableExpr : VariableExpr<T>
{
using VariableExpr<T>::gradPtr;
using VariableExpr<T>::gradxPtr;
using VariableExpr<T>::grad;

/// Construct an IndependentVariableExpr object with given value.
IndependentVariableExpr(const T& v) : VariableExpr<T>(v) {}

virtual void propagate(const T& wprime) {
if(gradPtr) { *gradPtr += wprime; }
virtual void propagate(const T& wprime)
{
grad += wprime;
}

virtual void propagatex(const ExprPtr<T>& wprime)
Expand All @@ -309,8 +307,8 @@ struct IndependentVariableExpr : VariableExpr<T>
template<typename T>
struct DependentVariableExpr : VariableExpr<T>
{
using VariableExpr<T>::gradPtr;
using VariableExpr<T>::gradxPtr;
using VariableExpr<T>::grad;

/// The expression tree that defines how the dependent variable is calculated.
ExprPtr<T> expr;
Expand All @@ -320,7 +318,7 @@ struct DependentVariableExpr : VariableExpr<T>

virtual void propagate(const T& wprime)
{
if(gradPtr) { *gradPtr += wprime; }
grad += wprime;
expr->propagate(wprime);
}

Expand Down Expand Up @@ -1064,6 +1062,16 @@ struct Variable
/// Default copy assignment
Variable &operator=(const Variable &) = default;

/// Return a pointer to the underlying VariableExpr object in this variable.
auto __variableExpr() const { return static_cast<VariableExpr<T>*>(expr.get()); }

/// Return the derivative value stored in this variable.
auto grad() const { return __variableExpr()->grad; }


/// Reeet the derivative value stored in this variable to zero.
auto seed() { __variableExpr()->grad = 0; }

/// Implicitly convert this Variable object into an expression pointer.
operator ExprPtr<T>() const { return expr; }

Expand Down Expand Up @@ -1272,22 +1280,26 @@ auto wrt(Args&&... args)
return Wrt<Args&&...>{ std::forward_as_tuple(std::forward<Args>(args)...) };
}

/// Seed each variable in the **wrt** list.
template<typename... Vars>
auto seed(const Wrt<Vars...>& wrt)
{
constexpr static auto N = sizeof...(Vars);
For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).seed();
});
}
/// Return the derivatives of a dependent variable y with respect given independent variables.
template<typename T, typename... Vars>
auto derivatives(const Variable<T>& y, const Wrt<Vars...>& wrt)
{
constexpr auto N = sizeof...(Vars);
std::array<T, N> values;
values.fill(0.0);

For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).expr->bind_value(&values.at(i));
});

seed(wrt);
y.expr->propagate(1.0);

constexpr static auto N = sizeof...(Vars);
std::array<T, N> values;
For<N>([&](auto i) constexpr {
std::get<i>(wrt.args).expr->bind_value(nullptr);
values[i.index] = std::get<i>(wrt.args).grad();
});

return values;
Expand Down