Skip to content

Commit

Permalink
Temporarily remove generalization
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Feb 14, 2022
1 parent 525eac9 commit b0af1e5
Showing 1 changed file with 13 additions and 249 deletions.
262 changes: 13 additions & 249 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Transforms/Utils/Cloning.h"
#include <algorithm>
#include <cstdlib>
#include <llvm-7/llvm/IR/Instruction.h>
#include <map>
#include <utility>

Expand Down Expand Up @@ -4739,249 +4740,6 @@ class AdjointGenerator
AtlasConj = 114
};

// Coarse type tags used by the declarative BLAS rule tables, both for the
// arguments of a BLAS routine and for the return type of a routine or of a
// gradient operation. generateBLASAdjoint only considers `Real` and
// `RealPointer` arguments when deciding which arguments are active, and it
// maps `Void` / `Real` return tags to the LLVM void type / the routine's
// scalar type respectively.
enum BLASType { Void, Integer, Real, RealPointer };
// One formal parameter of a BLAS routine as described in a rule table.
// Entries of BLASAdjointGenerator::signature are matched positionally against
// the operands of the call being differentiated.
struct BLASArgument {
// Symbolic name by which chain rules and cache knowledge refer to the
// argument (e.g. "n", "x", "incx").
std::string name;
// Type tag of the argument.
BLASType type;
};
// Kind of a single step in a chain rule. In generateBLASAdjoint, `BLASCall`
// emits a call to the function named by GradientOp::callee and `AddToDiffe`
// forwards the first two resolved arguments to addToDiffe; `HelperCall` has
// no handler there and ends up in the "Unreachable" assertion.
enum GradientOpType { BLASCall, HelperCall, AddToDiffe };
// One step of a chain rule, interpreted by generateBLASAdjoint.
struct GradientOp {
// What kind of step this is (see GradientOpType).
GradientOpType gradientOpType;
// Name of the function to call; used for `BLASCall` steps.
std::string callee;
// Return type tag of the step; used to build the callee's function type.
BLASType returnType;
// Symbolic arguments of the step. generateBLASAdjoint recognises:
//   "<name>"  - the reverse-pass value of the call operand called <name>
//   "_<name>" - the inverted pointer (shadow) of operand <name>; the whole
//               step is skipped when that operand is not active
//   "1.0"     - the floating-point constant 1.0 of the routine's scalar type
//   "%<N>"    - the result of step number N of the same chain rule
// Anything else triggers the "Not handled" assertion.
std::vector<std::string> args;
};
// The list of gradient steps emitted for one active argument of a BLAS
// routine. An empty list means no reverse-pass code is emitted for it.
struct ChainRule {
// Steps of the rule, in emission order.
std::vector<GradientOp> gradientOps;
};
// Declarative description of how to differentiate one BLAS routine; consumed
// by generateBLASAdjoint.
struct BLASAdjointGenerator {
// Routine name without prefix and precision letter (e.g. "swap");
// generateBLASAdjoint compares it with the callee name from its eighth
// character onwards.
std::string name;
// Return type tag of the routine itself.
BLASType returnType;
// Formal parameters, in call-operand order.
std::vector<BLASArgument> signature;
// Chain rule for each differentiable argument, keyed by argument name.
std::map<std::string, ChainRule> chainrules;
// For an argument that may have to be cached: the names of the arguments
// holding its element count (first) and its stride (second).
std::map<std::string, std::pair<std::string, std::string>> cacheKnowledge;
};

// Emit the reverse-mode adjoint of a BLAS call from the declarative rule `g`.
//
// The callee name is expected to look like a 6-character prefix, one precision
// character ('s' selects float, 'd' selects double) and then the routine name.
//
// Parameters:
//   g                - rule description (signature, chain rules, cache info).
//   BuilderZ         - builder used for the forward-pass (caching) code.
//   call             - the call in the original function.
//   called           - unused here; kept for interface compatibility.
//   newCall          - the clone of `call`; allocations are inserted before it.
//   funcName         - name of the callee.
//   uncacheable_args - per-argument flag saying whether the pointed-to data
//                      may be overwritten before the reverse pass.
//
// Returns false without emitting anything when `funcName` does not match
// `g.name` or carries an unknown precision character, and true once the rule
// has been applied.
bool generateBLASAdjoint(BLASAdjointGenerator g, IRBuilder<> BuilderZ,
                         CallInst &call, Function *called, CallInst &newCall,
                         StringRef funcName,
                         const std::map<Argument *, bool> &uncacheable_args) {
  // Names shorter than prefix + precision letter can never match; checking
  // the length first avoids std::out_of_range from substr()/operator[].
  const std::string fullName = funcName.str();
  if (fullName.size() < 7 || fullName.substr(7) != g.name)
    return false;

  // 1. prepare
  Type *scalarType = nullptr;
  switch (fullName[6]) {
  case 's':
    scalarType = Type::getFloatTy(call.getContext());
    break;
  case 'd':
    scalarType = Type::getDoubleTy(call.getContext());
    break;
  default:
    // Unknown precision: let the caller fall back to other handlers.
    return false;
  }
  auto currentModule = gutils->oldFunc->getParent();
  auto calledFunc = call.getCalledFunction();
  auto size = ConstantExpr::getSizeOf(scalarType);
  std::vector<Value *> args(g.signature.size());
  std::vector<Type *> types(g.signature.size());
  std::vector<bool> activities(g.signature.size(), false);
  std::vector<unsigned> actives;
  std::map<std::string, unsigned> nm2idx;
  std::map<unsigned, std::string> idx2nm;
  for (unsigned i = 0; i != g.signature.size(); ++i) {
    args[i] = call.getArgOperand(i);
    types[i] = args[i]->getType();
    idx2nm[i] = g.signature[i].name;
    nm2idx[g.signature[i].name] = i;
    // Only floating-point data can carry a derivative.
    if (g.signature[i].type == Real || g.signature[i].type == RealPointer) {
      activities[i] = !gutils->isConstantValue(args[i]);
      if (activities[i]) {
        actives.push_back(i);
      }
    }
  }

  // For every argument that has a chain rule, record which primal arguments
  // the rule reads.
  // NOTE(review): `needed` below asks whether any argument read by i's own
  // rule is active, rather than whether an active argument's rule reads i;
  // the two agree for symmetric rules only - confirm the intended direction.
  std::vector<unsigned> cacheArgIndices;
  std::map<unsigned, std::vector<unsigned>> dependencies;
  for (const auto &[name, chainrule] : g.chainrules) {
    unsigned idx = nm2idx[name];
    for (const auto &gradientop : chainrule.gradientOps) {
      for (const auto &arg : gradientop.args) {
        auto argIt = nm2idx.find(arg);
        if (argIt != nm2idx.end()) {
          dependencies[idx].push_back(argIt->second);
        }
      }
    }
  }
  for (unsigned i = 0; i != g.signature.size(); ++i) {
    if (g.signature[i].type != RealPointer)
      continue;
    // Without information about the argument (indirect call, or no entry in
    // the map) conservatively assume the data may be overwritten.
    bool modified = true;
    if (calledFunc) {
      auto found = uncacheable_args.find(calledFunc->arg_begin() + i);
      if (found != uncacheable_args.end())
        modified = found->second;
    }
    bool needed =
        std::any_of(dependencies[i].begin(), dependencies[i].end(),
                    [&](unsigned j) { return activities[j]; });
    if (modified && needed) {
      cacheArgIndices.push_back(i);
    }
  }

  // 2. do cache
  Value *cacheValue = nullptr;
  Type *cacheType = nullptr;
  std::map<unsigned, unsigned> sz, inc;
  for (const auto &[name, p] : g.cacheKnowledge) {
    sz[nm2idx[name]] = nm2idx[p.first];
    inc[nm2idx[name]] = nm2idx[p.second];
  }
  if (!cacheArgIndices.empty()) {
    std::vector<Type *> cacheFieldTypes;
    cacheFieldTypes.reserve(cacheArgIndices.size());
    for (unsigned i : cacheArgIndices)
      cacheFieldTypes.push_back(args[i]->getType());
    cacheType = StructType::get(call.getContext(), cacheFieldTypes);
    cacheValue = UndefValue::get(cacheType);
    if (Mode == DerivativeMode::ReverseModeCombined ||
        Mode == DerivativeMode::ReverseModePrimal) {
      for (unsigned cacheRank = 0; cacheRank != cacheArgIndices.size();
           ++cacheRank) {
        unsigned index = cacheArgIndices[cacheRank];
        assert(sz.count(index) && inc.count(index) &&
               "Missing cache knowledge for a cached BLAS argument");
        // NOTE(review): the size/stride operands below are values of the
        // original function; confirm whether they need getNewFromOriginal.
        auto toCopy = gutils->getNewFromOriginal(args[index]);
        auto allocation =
            CallInst::CreateMalloc(&newCall, size->getType(), scalarType,
                                   size, args[sz[index]], nullptr, "");
        auto bitcast = BuilderZ.CreateBitCast(allocation, toCopy->getType());
        auto stridedMemcpy = getOrInsertMemcpyStrided(
            *currentModule, PointerType::getUnqual(scalarType), 0, 0);
        BuilderZ.CreateCall(stridedMemcpy, {bitcast, toCopy, args[sz[index]],
                                            args[inc[index]]});
        cacheValue =
            BuilderZ.CreateInsertValue(cacheValue, bitcast, cacheRank);
      }
      gutils->cacheForReverse(BuilderZ, cacheValue,
                              getIndex(&call, CacheType::Tape));
    }
  }

  // 3. generate gradient
  if (Mode == DerivativeMode::ReverseModeCombined ||
      Mode == DerivativeMode::ReverseModeGradient) {
    IRBuilder<> Builder2(call.getParent());
    getReverseBuilder(Builder2);
    std::vector<Value *> newargs;
    newargs.reserve(args.size());
    for (Value *arg : args)
      newargs.push_back(lookup(gutils->getNewFromOriginal(arg), Builder2));
    if (!cacheArgIndices.empty()) {
      if (Mode == DerivativeMode::ReverseModeGradient) {
        cacheValue = BuilderZ.CreatePHI(cacheType, 0);
      }
      cacheValue =
          lookup(gutils->cacheForReverse(BuilderZ, cacheValue,
                                         getIndex(&call, CacheType::Tape)),
                 Builder2);
      for (unsigned cacheRank = 0; cacheRank != cacheArgIndices.size();
           ++cacheRank) {
        unsigned index = cacheArgIndices[cacheRank];
        // The looked-up tape value lives in the reverse pass, so the field
        // has to be extracted with the reverse builder as well.
        newargs[index] = Builder2.CreateExtractValue(cacheValue, cacheRank);
        // The cached copy replaces the original stride by 1.
        newargs[inc[index]] = ConstantInt::get(types[inc[index]], 1);
      }
    }
    for (const auto &index : actives) {
      // An active argument without a chain rule contributes nothing.
      auto ruleIt = g.chainrules.find(idx2nm[index]);
      if (ruleIt == g.chainrules.end())
        continue;
      const std::vector<GradientOp> &gradientops = ruleIt->second.gradientOps;
      std::vector<Value *> intermediateResults(gradientops.size(), nullptr);
      // `step` is the position in the rule, which is what "%N" refers to;
      // skipped steps keep their slot (as nullptr).
      for (size_t step = 0; step != gradientops.size(); ++step) {
        const GradientOp &gradientop = gradientops[step];
        std::vector<Type *> argTypes(gradientop.args.size(), nullptr);
        std::vector<Value *> argValues(gradientop.args.size(), nullptr);
        Type *returnType = nullptr;
        if (gradientop.returnType == Void) {
          returnType = Type::getVoidTy(call.getContext());
        } else if (gradientop.returnType == Real) {
          returnType = scalarType;
        }
        bool skip = false;
        for (unsigned i = 0; i != gradientop.args.size(); ++i) {
          const std::string &s = gradientop.args[i];
          auto primalIt = nm2idx.find(s);
          auto shadowIt = (!s.empty() && s[0] == '_')
                              ? nm2idx.find(s.substr(1))
                              : nm2idx.end();
          if (primalIt != nm2idx.end()) {
            // Primal operand, as seen from the reverse pass.
            argTypes[i] = types[primalIt->second];
            argValues[i] = newargs[primalIt->second];
          } else if (shadowIt != nm2idx.end()) {
            // Shadow of a signature operand; only exists when it is active.
            const unsigned shadowIndex = shadowIt->second;
            if (!activities[shadowIndex]) {
              skip = true;
              break;
            }
            argTypes[i] = types[shadowIndex];
            argValues[i] = lookup(
                gutils->invertPointerM(args[shadowIndex], Builder2), Builder2);
          } else if (s == "1.0") {
            argTypes[i] = scalarType;
            argValues[i] = ConstantFP::get(scalarType, 1.0);
          } else if (!s.empty() && s[0] == '%') {
            const int imindex = std::stoi(s.substr(1));
            const bool inRange =
                imindex >= 0 &&
                static_cast<size_t>(imindex) < intermediateResults.size();
            assert(inRange && "Intermediate result index out of range");
            Value *prior = inRange ? intermediateResults[imindex] : nullptr;
            if (!prior) {
              // The producing step was skipped (or has not run), so this
              // step has nothing to consume.
              skip = true;
              break;
            }
            argValues[i] = prior;
            argTypes[i] = prior->getType();
          } else {
            assert(false && "Not handled");
          }
        }
        if (skip)
          continue;
        switch (gradientop.gradientOpType) {
        case BLASCall: {
          if (!returnType) {
            assert(false && "Not handled");
            break;
          }
          auto functionType = FunctionType::get(returnType, argTypes, false);
          auto functionCallee = currentModule->getOrInsertFunction(
              gradientop.callee, functionType);
          auto gradientCall = Builder2.CreateCall(functionCallee, argValues);
          intermediateResults[step] = gradientCall;
        } break;
        case AddToDiffe:
          assert(argValues.size() >= 2 && "AddToDiffe needs two arguments");
          addToDiffe(argValues[0], argValues[1], Builder2, scalarType);
          break;
        default:
          assert(false && "Unreachable");
        }
      }
    }
  }
  if (gutils->knownRecomputeHeuristic.find(&call) !=
      gutils->knownRecomputeHeuristic.end()) {
    if (!gutils->knownRecomputeHeuristic[&call]) {
      gutils->cacheForReverse(BuilderZ, &newCall,
                              getIndex(&call, CacheType::Self));
    }
  }

  if (Mode == DerivativeMode::ReverseModeGradient) {
    eraseIfUnused(call, /*erase*/ true, /*check*/ false);
  } else {
    eraseIfUnused(call);
  }
  return true;
}

// Table-driven variant of the BLAS handler (work in progress).
//
// Asserts that the current mode is a reverse mode, positions a builder at the
// clone of `call` with fast-math flags, and builds the rule description for
// `swap`. The rule table is not consulted yet, so this overload never claims
// to have handled the call.
//
// Parameters mirror the four-argument overload; the trailing unnamed `bool`
// only distinguishes the two overloads.
//
// Returns `handled`, which is currently always false.
bool handleBLAS(CallInst &call, Function *called, StringRef funcName,
                const std::map<Argument *, bool> &uncacheable_args, bool) {
  assert((Mode != DerivativeMode::ForwardMode &&
          Mode != DerivativeMode::ForwardModeSplit) &&
         "Forward mode is not yet handled with BLAS");
  CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
  IRBuilder<> BuilderZ(newCall);
  BuilderZ.setFastMathFlags(getFast());
  auto module = gutils->oldFunc->getParent();
  auto &context = call.getContext();
  Type *voidType = Type::getVoidTy(context);
  Type *scalarType = getBLASInnerType(funcName, context);
  bool handled = false;

  // Only the signature of swap is described so far; its chain rule for "x"
  // is still empty, so applying it would emit no gradient code.
  BLASAdjointGenerator swap{.name = "swap",
                            .returnType = Void,
                            .signature = {{"n", Integer},
                                          {"x", RealPointer},
                                          {"incx", Integer},
                                          {"y", RealPointer},
                                          {"incy", Integer}},
                            .chainrules = {{"x", ChainRule()}}};
  std::vector<BLASAdjointGenerator> adjointGenerators;

  // Flowing off the end of a non-void function is undefined behaviour;
  // report explicitly that nothing was handled.
  return handled;
}

bool handleBLAS(CallInst &call, Function *called, StringRef funcName,
const std::map<Argument *, bool> &uncacheable_args) {
// Forward Mode not handled yet
Expand All @@ -4998,7 +4756,8 @@ class AdjointGenerator
Type *scalarType = getBLASInnerType(funcName, context);
bool handled = false;

if (funcName == "cblas_ddot" || funcName == "cblas_sdot") {
if ((funcName == "cblas_ddot" || funcName == "cblas_sdot") &&
called->isDeclaration()) {
// double sdot(int n, float *x, int incx, float *y, int incy)
handled = true;
std::string axpyName = getBLASName(funcName, "axpy");
Expand Down Expand Up @@ -5178,7 +4937,8 @@ class AdjointGenerator
}
}

if (funcName == "cblas_sswap" || funcName == "cblas_dswap") {
if ((funcName == "cblas_sswap" || funcName == "cblas_dswap") &&
called->isDeclaration()) {
// sswap(int n, float *x, int incx, float *y, int incy)
handled = true;
auto arg_n = call.getArgOperand(0), arg_x = call.getArgOperand(1),
Expand Down Expand Up @@ -5218,7 +4978,8 @@ class AdjointGenerator
}
}

if (funcName == "cblas_scopy" || funcName == "cblas_dcopy") {
if ((funcName == "cblas_scopy" || funcName == "cblas_dcopy") &&
called->isDeclaration()) {
// void scopy(int n, float *x, int incx, float *y, int incy)
handled = true;
std::string axpyName = getBLASName(funcName, "axpy");
Expand Down Expand Up @@ -5262,7 +5023,8 @@ class AdjointGenerator
}
}

if (funcName == "cblas_sscal" || funcName == "cblas_dscal") {
if ((funcName == "cblas_sscal" || funcName == "cblas_dscal") &&
called->isDeclaration()) {
// sscal(int n, float alpha, float *x, int incx)
handled = true;
std::string dotName = getBLASName(funcName, "dot"),
Expand Down Expand Up @@ -5365,7 +5127,8 @@ class AdjointGenerator
}
}

if (funcName == "cblas_saxpy" || funcName == "cblas_daxpy") {
if ((funcName == "cblas_saxpy" || funcName == "cblas_daxpy") &&
called->isDeclaration()) {
// saxpy(int n, float alpha, float *x, int incx, float *y, int incy)
handled = true;
std::string dotName = getBLASName(funcName, "dot");
Expand Down Expand Up @@ -5481,7 +5244,8 @@ class AdjointGenerator
}
}

if (funcName == "cblas_snrm2" || funcName == "cblas_dnrm2") {
if ((funcName == "cblas_snrm2" || funcName == "cblas_dnrm2") &&
called->isDeclaration()) {
// snrm2(int n, float *x, int incx)
handled = true;
std::string axpyName = getBLASName(funcName, "axpy");
Expand Down

0 comments on commit b0af1e5

Please sign in to comment.