Skip to content

Commit

Permalink
Better sret (rust-lang#708)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 24, 2022
1 parent f6fed13 commit 68db988
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 3 deletions.
25 changes: 22 additions & 3 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,14 +1254,19 @@ class Enzyme : public ModulePass {
#endif
CI->replaceAllUsesWith(cload);
} else {
llvm::errs() << *CI << " - " << *diffret << "\n";
assert(0 && " what");
EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
"Cannot cast return type of gradient ",
*diffret->getType(), *diffret, ", to desired type ",
*CI->getType());
return false;
}
} else if (CI->hasStructRetAttr()) {
Value *sret = CI->getArgOperand(0);
PointerType *stype = cast<PointerType>(sret->getType());
StructType *st = dyn_cast<StructType>(stype->getElementType());

// Assign results to struct allocated at the call site.
if (StructType *st = cast<StructType>(diffret->getType())) {
if (st && st->isLayoutIdentical(diffretsty)) {
for (unsigned int i = 0; i < st->getNumElements(); i++) {
#if LLVM_VERSION_MAJOR > 7
Value *sgep = Builder.CreateStructGEP(
Expand All @@ -1271,6 +1276,20 @@ class Enzyme : public ModulePass {
#endif
Builder.CreateStore(Builder.CreateExtractValue(diffret, {i}), sgep);
}
} else {
auto &DL = fn->getParent()->getDataLayout();
if (DL.getTypeSizeInBits(stype->getElementType()) !=
DL.getTypeSizeInBits(diffret->getType())) {
EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
"Cannot cast return type of gradient ",
*diffret->getType(), *diffret, ", to desired type ",
*stype->getElementType());
return false;
}
Builder.CreateStore(
diffret, Builder.CreatePointerCast(
sret, PointerType::get(diffret->getType(),
stype->getAddressSpace())));
}
} else {

Expand Down
38 changes: 38 additions & 0 deletions enzyme/test/Integration/ReverseMode/sret.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S

#include "test_utils.h"
#include <iostream>
#include <sstream>
#include <utility>

typedef struct {
double df[3];
} Gradient;
extern Gradient __enzyme_autodiff(void*, double, double , double);

double myfunction(double x, double y, double z){
return x * y * z;
}

void dmyfunction(double x, double y, double z, double* res) {
Gradient g = __enzyme_autodiff((void*)myfunction, x, y, z);

res[0]=g.df[0];
res[1]=g.df[1];
res[2]=g.df[2];
}

int main() {
double *res=new double(3);
dmyfunction(3,4,5,res);
APPROX_EQ(res[0], 4*5., 1e-7);
APPROX_EQ(res[1], 3*5., 1e-7);
APPROX_EQ(res[1], 3*4., 1e-7);
}

0 comments on commit 68db988

Please sign in to comment.