Skip to content

Commit

Permalink
Make GraphOperationInfo constructor take a const GraphOperationInst* …
Browse files Browse the repository at this point in the history
…parameter (#20978)
  • Loading branch information
bgogul committed Dec 4, 2018
1 parent 86f99de commit bcf82dc
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 28 deletions.
6 changes: 3 additions & 3 deletions include/swift/SIL/GraphOperationInfo.h
Expand Up @@ -158,7 +158,7 @@ struct GraphOperationInfo {

private:
/// The instruction being analyzed.
GraphOperationInst *inst;
const GraphOperationInst *inst;

/// The TensorFlow op name, decoded from inst.
StringRef OperationName;
Expand All @@ -168,10 +168,10 @@ struct GraphOperationInfo {
llvm::SmallVector<StructuredArgument, 4> StructuredArguments;

public:
explicit GraphOperationInfo(GraphOperationInst *inst);
explicit GraphOperationInfo(const GraphOperationInst *inst);

/// Get the instruction being analyzed.
GraphOperationInst *getInst() const {
const GraphOperationInst *getInst() const {
return inst;
}

Expand Down
2 changes: 1 addition & 1 deletion include/swift/SIL/SILInstruction.h
Expand Up @@ -7945,7 +7945,7 @@ class GraphOperationInst final
return Attributes;
}

Optional<SymbolicValue> getAttributeNamed(StringRef name);
Optional<SymbolicValue> getAttributeNamed(StringRef name) const;

void setNoClustering(bool noClustering) { NoClustering = noClustering; }
bool getNoClustering() const { return NoClustering; }
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/GraphOperationInfo.cpp
Expand Up @@ -20,7 +20,7 @@ using llvm::StringRef;
using namespace swift;
using namespace tf;

GraphOperationInfo::GraphOperationInfo(GraphOperationInst *inst) : inst(inst) {
GraphOperationInfo::GraphOperationInfo(const GraphOperationInst *inst) : inst(inst) {
PrettyStackTraceSILNode X("decoding graph_op name", inst);

ArrayRef<Operand> remainingOperands = inst->getAllOperands();
Expand Down
2 changes: 1 addition & 1 deletion lib/SIL/SILInstructions.cpp
Expand Up @@ -2577,7 +2577,7 @@ GraphOperationAttribute GraphOperationInst::getAttribute(unsigned i) const {
}


Optional<SymbolicValue> GraphOperationInst::getAttributeNamed(StringRef name) {
Optional<SymbolicValue> GraphOperationInst::getAttributeNamed(StringRef name) const {
for (auto attr : getAttributes())
if (attr.name.is(name))
return attr.value;
Expand Down
18 changes: 11 additions & 7 deletions lib/SILOptimizer/Mandatory/TFDeabstraction.cpp
Expand Up @@ -156,10 +156,11 @@ namespace {
void prepareStackAllocForPromotion(AllocStackInst *alloc);
void propagateSSAValues();
void checkAttributesAndFormGraphOps();
void evaluateAttributesAndDoPacking(
GraphOperationInfo &opInfo,
DenseMap<SILValue, SymbolicValue> &constants,
GraphFunctionDeviceInfo &deviceInfo);
void
evaluateAttributesAndDoPacking(GraphOperationInst *origInst,
GraphOperationInfo &opInfo,
DenseMap<SILValue, SymbolicValue> &constants,
GraphFunctionDeviceInfo &deviceInfo);
void cleanupDeadInstructions();
};
} // end anonymous namespace
Expand Down Expand Up @@ -2012,7 +2013,8 @@ void TFDeabstraction::checkAttributesAndFormGraphOps() {
opInfo.getOperationName() == "tfc.configureGPU" ||
opInfo.getOperationName() == "tfc.configureCPU")
continue;
evaluateAttributesAndDoPacking(opInfo, constants, deviceInfo);
evaluateAttributesAndDoPacking(graphOpInst, opInfo, constants,
deviceInfo);
// evaluateAttributesAndDoPacking deletes inst. So, continue as the rest
// of the loop is irrelevant. (This also avoid memory errors.)
continue;
Expand Down Expand Up @@ -2264,9 +2266,11 @@ static bool collectInnermostTensorFlowDTypes(
/// This deletes the underlying inst in `opInfo` when a GraphOperation is
/// created successfully.
void TFDeabstraction::evaluateAttributesAndDoPacking(
GraphOperationInfo &opInfo, DenseMap<SILValue, SymbolicValue> &constants,
GraphOperationInst *origInst, GraphOperationInfo &opInfo,
DenseMap<SILValue, SymbolicValue> &constants,
GraphFunctionDeviceInfo &deviceInfo) {
auto *origInst = opInfo.getInst();
assert(opInfo.getInst() == origInst &&
"Instruction and GraphOperationInfo don't match.");
auto &context = origInst->getFunction()->getASTContext();
auto &allocator = context.getAllocator();
SILBuilder B(origInst);
Expand Down
24 changes: 12 additions & 12 deletions lib/SILOptimizer/Mandatory/TFLowerGraph.cpp
Expand Up @@ -611,18 +611,18 @@ struct TFGraphFunctionLowering
GLStatus visitGraphOpD2DTensorSendInst(GraphOperationInfo &graphOpInfo);

// Helper functions to add different flavors of send/recv TF ops.
GLStatus addTFRecvOp(SILInstruction *inst, int transferId,
GLStatus addTFRecvOp(const SILInstruction *inst, int transferId,
StringRef srcDevice);
GLStatus addTFSendOp(SILInstruction *inst, int transferId,
GLStatus addTFSendOp(const SILInstruction *inst, int transferId,
StringRef destDevice);
// For the TPU infeed/outfeed related ops, the shape array of the tensor being
// transferred is given by `dims`, `numDims` and `dimPtrs`.
GLStatus addTPUDequeueOp(SILInstruction *inst, bool isInfeed, int transferId,
ArrayRef<int64_t> dims, ArrayRef<int> numDims,
ArrayRef<int64_t *> dimPtrs);
GLStatus addTPUEnqueueOp(SILInstruction *inst, bool isInfeed, int transferId,
ArrayRef<int64_t> dims, ArrayRef<int> numDims,
ArrayRef<int64_t *> dimPtrs);
GLStatus addTPUDequeueOp(const SILInstruction *inst, bool isInfeed,
int transferId, ArrayRef<int64_t> dims,
ArrayRef<int> numDims, ArrayRef<int64_t *> dimPtrs);
GLStatus addTPUEnqueueOp(const SILInstruction *inst, bool isInfeed,
int transferId, ArrayRef<int64_t> dims,
ArrayRef<int> numDims, ArrayRef<int64_t *> dimPtrs);

// For `op` with `opName` under construction, set a function-typed attribute
// with a graph function name derived from `silFuncName` under the following
Expand Down Expand Up @@ -1169,7 +1169,7 @@ GLStatus TFGraphFunctionLowering::visitGraphOpRecvFromHostInst(
return GLStatus::Success;
}

GLStatus TFGraphFunctionLowering::addTFRecvOp(SILInstruction *inst,
GLStatus TFGraphFunctionLowering::addTFRecvOp(const SILInstruction *inst,
int transferId,
StringRef srcDevice) {
auto opName = "tf_recv_" + llvm::itostr(transferId);
Expand Down Expand Up @@ -1197,7 +1197,7 @@ GLStatus TFGraphFunctionLowering::addTFRecvOp(SILInstruction *inst,
return GLStatus::Success;
}

GLStatus TFGraphFunctionLowering::addTPUDequeueOp(SILInstruction *inst,
GLStatus TFGraphFunctionLowering::addTPUDequeueOp(const SILInstruction *inst,
bool isInfeed, int transferId,
ArrayRef<int64_t> dims,
ArrayRef<int> numDims,
Expand Down Expand Up @@ -1298,7 +1298,7 @@ GLStatus TFGraphFunctionLowering::visitGraphOpD2DTensorRecvInst(
}
}

GLStatus TFGraphFunctionLowering::addTFSendOp(SILInstruction *inst,
GLStatus TFGraphFunctionLowering::addTFSendOp(const SILInstruction *inst,
int transferId,
StringRef destDevice) {
auto opName = "tf_send_" + llvm::itostr(transferId);
Expand Down Expand Up @@ -1328,7 +1328,7 @@ GLStatus TFGraphFunctionLowering::addTFSendOp(SILInstruction *inst,
return GLStatus::Success;
}

GLStatus TFGraphFunctionLowering::addTPUEnqueueOp(SILInstruction *inst,
GLStatus TFGraphFunctionLowering::addTPUEnqueueOp(const SILInstruction *inst,
bool isInfeed, int transferId,
ArrayRef<int64_t> dims,
ArrayRef<int> numDims,
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/TFUtilities.cpp
Expand Up @@ -240,7 +240,7 @@ SILLocation tf::getUserSourceLocation(SILValue value) {
/// Get the user's source location for the specified instruction. Because it
/// is an instruction, we can apply various heuristics to improve the
/// precision of the returned location information.
SILLocation tf::getUserSourceLocation(SILInstruction *inst) {
SILLocation tf::getUserSourceLocation(const SILInstruction *inst) {
// If we have a struct extract from a type like Int, Float, or Tensor of an
// internal type like Builtin.i64 or TensorHandle, look through it to the
// higher level type, which will have better source location information.
Expand Down
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Mandatory/TFUtilities.h
Expand Up @@ -91,7 +91,7 @@ SubstitutionMap getSingleSubstitutionMapForElementType(Type ty,
ASTContext &ctx);

/// `inst` must have a single result, and return that result value.
static inline SILValue getSingleValueResult(GraphOperationInst *inst) {
static inline SILValue getSingleValueResult(const GraphOperationInst *inst) {
assert(inst->getNumResults() == 1);
return inst->getResults()[0];
}
Expand All @@ -116,7 +116,7 @@ inline SILLocation getUserSourceLocation(SILDebugLocation loc) {
/// instruction, we can apply various heuristics to improve the precision of
/// the returned location information.
SILLocation getUserSourceLocation(SILValue value);
SILLocation getUserSourceLocation(SILInstruction *inst);
SILLocation getUserSourceLocation(const SILInstruction *inst);

//===--------------------------------------------------------------------===//
// Other stuff
Expand Down

0 comments on commit bcf82dc

Please sign in to comment.