Skip to content

Commit

Permalink
init mem args
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Ullrich committed Dec 8, 2022
1 parent 86799e8 commit 8a56cf0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dialects/autodiff/auxiliary/autodiff_rewrite_inner.cpp
Expand Up @@ -112,7 +112,7 @@ const Def* AutoDiffEval::augment_extract(const Extract* ext, Lam* f, Lam* f_diff
auto aug_tuple = augment(tuple, f, f_diff);
auto aug_index = augment(index, f, f_diff);

world.DLOG("tuple was: {} : {}", tuple, tuple->type());
world.DLOG("tuple was: {} : {} [{}]", tuple, tuple->type(), tuple->node_name());
world.DLOG("aug tuple: {} : {}", aug_tuple, aug_tuple->type());
auto aug_ext = world.extract(aug_tuple, aug_index);

Expand Down
38 changes: 37 additions & 1 deletion dialects/autodiff/auxiliary/mem/autodiff_mem.cpp
@@ -1,3 +1,5 @@
#include <cassert>

#include <thorin/axiom.h>
#include <thorin/def.h>
#include <thorin/lam.h>
Expand Down Expand Up @@ -60,6 +62,34 @@ const Def* AutoDiffEval::autodiff_zero(const Def* mem, const Def* def) {
assert(false && "unhandled type in autodiff_zero");
}

const Def* AutoDiffEval::preparePtr(const Def* mem, const Def* darg, Lam* f) {
auto& world = darg->world();
if (auto ptr = match<mem::Ptr>(darg->type())) {
auto [ptr_ty, addr_space] = ptr->args<2>();
// auto [mem2, gradient_ptr] = mem::op_alloc(ptr_ty, mem, world.dbg(darg->name() +
// "_gradient_arr"))->projs<2>(); mem = mem2; gradient_ptrs[darg] = gradient_ptr;
world.DLOG("preparePtr: {} : {}", darg, darg->type());
world.DLOG(" pointer type: {}", ptr_ty);
auto pb_ty = shadow_array_type(ptr_ty, f->dom(0_s));
world.DLOG(" pb ptr type: {}", pb_ty);

auto [mem2, pullback_ptr] = mem::op_malloc(pb_ty, mem, world.dbg(darg->name() + "_pullback_alloc"))->projs<2>();
world.DLOG(" pb type: {}", pullback_ptr->type());
shadow_pullback[darg] = pullback_ptr;
world.DLOG(" set pb for {} : {}", darg, darg->type());
// TODO: init pullback
mem = mem2;
} else if (darg->num_projs() > 1) {
for (auto arg : darg->projs()) {
// auto arg_ty = arg->type();
mem = preparePtr(mem, arg, f);
}
} else {
// world.DLOG("flat type: {}", darg->type());
}
return mem;
}

void AutoDiffEval::prepareMemArguments(Lam* lam, Lam* deriv) {
const Def* deriv_mem = mem::mem_var(deriv);
if (!deriv_mem) return;
Expand All @@ -70,6 +100,11 @@ void AutoDiffEval::prepareMemArguments(Lam* lam, Lam* deriv) {

// TODO: go deeper

world.DLOG("prepareMemArguments: {}", deriv_arg->type());

current_mem = preparePtr(current_mem, deriv_arg, lam);
// assert(0);

// for (auto arg : deriv_arg->projs()) {
// auto arg_ty = arg->type();
// if (auto ptr = match<mem::Ptr>(arg_ty)) {
Expand All @@ -86,7 +121,8 @@ void AutoDiffEval::prepareMemArguments(Lam* lam, Lam* deriv) {
// TODO: test if this works as intended
// deriv_mem |-> current_mem
// Alternatively to replace_mem, a subst call could be used.
augmented[lam->var()] = mem::replace_mem(current_mem, deriv->var());
augmented[lam->var()] = mem::replace_mem(current_mem, deriv->var());
shadow_pullback[augmented[lam->var()]] = shadow_pullback[deriv->var()];
}

const Def* AutoDiffEval::wrap_call_pullbacks(const Def* arg_pb, const Def* arg) {
Expand Down
5 changes: 4 additions & 1 deletion dialects/autodiff/auxiliary/mem/autodiff_mem_axioms.cpp
Expand Up @@ -31,7 +31,7 @@ const Def* AutoDiffEval::augment_lea(const App* lea, Lam* f, Lam* f_diff) {
if (gradient_array) {
gradient_ptrs[aug_lea] = mem::op_lea(gradient_array, aug_idx, w.dbg("pullback_lea"));
} else {
// TODO: incorporate gradient_ptrs in shadow, remove cases
w.DLOG("lea aug_ptr {} : {}", aug_ptr, aug_ptr->type());
auto pullback_array = shadow_pullback[aug_ptr];
assert(pullback_array);
shadow_pullback[aug_lea] = mem::op_lea(pullback_array, aug_idx, w.dbg("pullback_lea"));
Expand Down Expand Up @@ -110,6 +110,9 @@ const Def* AutoDiffEval::augment_alloc(const App* alloc, Lam* f, Lam* f_diff) {
// TODO: check if this should be gradient_ptrs
shadow_pullback[alloc_ptr] = pullback_ptr;

// We do not need an init here as the pullback will be present iff data in the pointer is present.
// Therefore, a store will always happen befoire the first load.

auto tup = world.tuple({alloc_mem_2, alloc_ptr});

// TODO: correct pullbacks instead
Expand Down
2 changes: 2 additions & 0 deletions dialects/autodiff/passes/autodiff_eval.h
Expand Up @@ -41,6 +41,8 @@ class AutoDiffEval : public RWPass<AutoDiffEval, Lam> {
/// This function generates the structure for the function arguments.
void prepareArguments(Lam* lam, Lam* deriv);

const Def* preparePtr(const Def* mem, const Def* arg, Lam* f);

// TODO: comment
const Def* buildAugmentedTuple(World& world, Defs aug_ops, const Pi* pb_ty, Lam* f, Lam* f_diff);

Expand Down

0 comments on commit 8a56cf0

Please sign in to comment.