From 7755f94085b3bf1f002c64fe1175eb51096655ca Mon Sep 17 00:00:00 2001 From: Pratyush Das Date: Mon, 11 Oct 2021 22:55:48 -0400 Subject: [PATCH] getNewFromOriginal --- enzyme/Enzyme/AdjointGenerator.h | 38 +++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 9be916eeaca8..12bccc5f9fac 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -6019,7 +6019,8 @@ class AdjointGenerator if (!Bcache) structarg2 = lookup(gutils->getNewFromOriginal(call.getArgOperand(9)), Builder2); - if (call.getArgOperand(0) == Builder2.getInt32(102)) { + if (gutils->getNewFromOriginal(call.getArgOperand(0)) == + Builder2.getInt32(102)) { if (aactive) { salda = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2); @@ -6032,7 +6033,8 @@ class AdjointGenerator sbldc = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), Builder2); } - } else if (call.getArgOperand(0) == Builder2.getInt32(101)) { + } else if (gutils->getNewFromOriginal(call.getArgOperand(0)) == + Builder2.getInt32(101)) { if (aactive) { salda = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), Builder2); @@ -6049,16 +6051,20 @@ class AdjointGenerator assert(false && "Wrong value"); CallInst *safunccall, *sbfunccall; if (aactive) { - if (call.getArgOperand(2) == Builder2.getInt32(112) || - call.getArgOperand(2) == Builder2.getInt32(113)) { + if (gutils->getNewFromOriginal(call.getArgOperand(2)) == + Builder2.getInt32(112) || + gutils->getNewFromOriginal(call.getArgOperand(2)) == + Builder2.getInt32(113)) { sabtrans = Builder2.getInt32(111); - if (call.getArgOperand(0) == Builder2.getInt32(102)) + if (gutils->getNewFromOriginal(call.getArgOperand(0)) == + Builder2.getInt32(102)) saldb = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), Builder2); else saldb = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), Builder2); - } else if (call.getArgOperand(2) == Builder2.getInt32(111)) { + } else if (gutils->getNewFromOriginal(call.getArgOperand(2)) == + Builder2.getInt32(111)) { sabtrans = Builder2.getInt32(112); saldb = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), Builder2); @@ -6087,18 +6093,23 @@ class AdjointGenerator safunccall = Builder2.CreateCall(dfunc, safuncargs); } if (bactive) { - if (call.getArgOperand(1) == Builder2.getInt32(112) || - call.getArgOperand(1) == Builder2.getInt32(113)) { + if (gutils->getNewFromOriginal(call.getArgOperand(1)) == + Builder2.getInt32(112) || + gutils->getNewFromOriginal(call.getArgOperand(1)) == + Builder2.getInt32(113)) { sbatrans = Builder2.getInt32(111); - if (call.getArgOperand(0) == Builder2.getInt32(102)) + if (gutils->getNewFromOriginal(call.getArgOperand(0)) == + Builder2.getInt32(102)) sblda = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), Builder2); else sblda = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2); - } else if (call.getArgOperand(1) == Builder2.getInt32(111)) { + } else if (gutils->getNewFromOriginal(call.getArgOperand(1)) == + Builder2.getInt32(111)) { sbatrans = Builder2.getInt32(112); - if (call.getArgOperand(0) == Builder2.getInt32(102)) + if (gutils->getNewFromOriginal(call.getArgOperand(0)) == + Builder2.getInt32(102)) sblda = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2); else @@ -6132,8 +6143,9 @@ class AdjointGenerator scfuncname, Builder2.getVoidTy(), Builder2.getInt32Ty(), call.getArgOperand(6)->getType(), call.getArgOperand(7)->getType(), Builder2.getInt32Ty()); - auto clen = - Builder2.CreateMul(call.getArgOperand(3), call.getArgOperand(4)); + auto clen = Builder2.CreateMul( + gutils->getNewFromOriginal(call.getArgOperand(3)), + gutils->getNewFromOriginal(call.getArgOperand(4))); SmallVector scfuncargs = { clen, lookup(gutils->getNewFromOriginal(call.getArgOperand(11)),