Skip to content

Commit

Permalink
Fixed a bug in createHostSend(), where the scalar to send may not be …
Browse files Browse the repository at this point in the history
…produced by a struct_extract inst (#17890)

Fixed a bug in createHostSend(), where the scalar to send may not be produced by
a struct_extract inst. Instead of pattern matching on the instruction type, we
now pattern match on the data type.

The test case is now failing with a different error related to loop
canonicalization (https://bugs.swift.org/browse/SR-7765), so it's commented out.
  • Loading branch information
Mingsheng Hong committed Jul 12, 2018
1 parent 2fe1865 commit 3d37ae7
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 12 deletions.
75 changes: 63 additions & 12 deletions lib/SILOptimizer/Mandatory/TFPartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2545,6 +2545,8 @@ static SILValue wrapInStruct(SILValue v, NominalTypeDecl *decl, SILBuilder &B,
/// `decl` is an stdlib numeric type represented by a struct wrapping an LLVM
/// builtin type, such as $Bool, $Float and $Int64. An example returned type is
/// $Builtin.Int1, in the case of $Bool as the input.
/// See getStdlibNumericTypeDeclFromBuiltinType() on the reverse type
/// conversion.
static SILType extractBuiltinTypeFromStdlibNumericType(NominalTypeDecl *decl) {
auto type = getSingleElementDeclFieldType(decl);
return SILType::getPrimitiveObjectType(type);
Expand Down Expand Up @@ -2751,6 +2753,48 @@ static T* castWithDebugInfo(SILInstruction *inst) {
return cast<T>(inst);
}

/// Given a builtin input type such as $Builtin.Int1, return the type decl of
/// the corresponding stdlib numeric type, such as $Bool.
/// See extractBuiltinTypeFromStdlibNumericType() for the reverse type
/// conversion.
///
/// This function only supports element data types that are accelerable by
/// tensorflow (e.g. BuiltinFloatType::IEEE80 and Float80 are not).
///
/// \param ty  the builtin type to map.
/// \param ctx the AST context that the stdlib decls are looked up in.
/// \returns the matching stdlib decl, or nullptr when `ty` is not a builtin
/// integer/float type or has no supported stdlib counterpart. Callers are
/// expected to handle the nullptr case themselves.
static NominalTypeDecl *
getStdlibNumericTypeDeclFromBuiltinType(Type ty, ASTContext &ctx) {
  // BuiltinIntegerType doesn't carry sign information, which TensorFlow needs,
  // so we can't rely on getting type information from the builtin types
  // themselves. For now we'll just use signed types.
  if (auto *BII = ty->getAs<BuiltinIntegerType>()) {
    // getFixedWidth() is only valid on fixed-width integers. Report any other
    // integer (e.g. a pointer-width one) as unsupported instead of querying a
    // width it does not have.
    if (!BII->isFixedWidth())
      return nullptr;

    switch (BII->getFixedWidth()) {
    case 1:
      return ctx.getBoolDecl();
    case 8:
      return ctx.getInt8Decl();
    case 16:
      return ctx.getInt16Decl();
    case 32:
      return ctx.getInt32Decl();
    case 64:
      return ctx.getInt64Decl();
    default:
      // No stdlib type is mapped for any other bit width.
      return nullptr;
    }
  }

  if (auto *BIF = ty->getAs<BuiltinFloatType>()) {
    // Every FPKind is listed explicitly (no `default`) so that adding a new
    // kind produces a compiler warning here.
    switch (BIF->getFPKind()) {
    case BuiltinFloatType::IEEE32:
      return ctx.getFloatDecl();
    case BuiltinFloatType::IEEE64:
      return ctx.getDoubleDecl();
    case BuiltinFloatType::IEEE16:
    case BuiltinFloatType::IEEE80:
    case BuiltinFloatType::IEEE128:
    case BuiltinFloatType::PPC128:
      return nullptr;
    }
  }

  // Not a builtin integer or float type.
  return nullptr;
}

// Create a call to runtime API @_swift_tfc_SendTensorHandle() for the host
// program to send a tensor to a TF-managed Fifo queue.
//
Expand Down Expand Up @@ -2779,23 +2823,27 @@ static SILValue createHostSend(SILBuilder &B, SILLocation loc, SILValue value,

auto &ctx = B.getASTContext();
Type scalarValueTy, tensorValueTy;
if (!isTensorHandle(value->getType().getASTType())) {
if (isTensorHandle(value->getType().getASTType())) {
tensorValueTy = value->getType().getASTType();
scalarValueTy = getTensorHandleElementType(tensorValueTy);
} else {
assert(createScalarTensorFn);
// Here scalar type is something like $Builtin.FPIEEE32 -- convert it to an
// AccelerableByTensorFlow conforming type like Float first, and then create
// a scalar tensor to send that value.
scalarValueTy = value->getType().getASTType();
if (!scalarValueTy->getAs<StructType>()) {
// The value must be defined by a struct_extract like:
// %34 = struct_extract %33 : $Float, #Float._value
//
// In this case we set `value` to the Float operand like %33 above.
auto *SEI =
castWithDebugInfo<StructExtractInst>(value->getDefiningInstruction());
assert(SEI->getFieldNo() == 0);
value = SEI->getOperand();
assert(value->getDefiningInstruction());
auto *typeDecl =
getStdlibNumericTypeDeclFromBuiltinType(scalarValueTy, ctx);
if (!typeDecl) {
value->dump();
llvm_unreachable("Unexpected data type when sending a scalar!");
}
value = wrapInStruct(value, typeDecl, B, loc);
scalarValueTy = value->getType().getASTType();
}
assert(scalarValueTy->getAs<StructType>());
tensorValueTy =
convertElementTypeToTensorValueType(scalarValueTy, ctx).getASTType();

Expand Down Expand Up @@ -2833,9 +2881,6 @@ static SILValue createHostSend(SILBuilder &B, SILLocation loc, SILValue value,
// Finish our read access and free the stack memory.
B.createEndAccess(loc, access, /*aborted*/ false);
B.createDeallocStack(loc, stackAlloc);
} else {
tensorValueTy = value->getType().getASTType();
scalarValueTy = getTensorHandleElementType(tensorValueTy);
}

auto sendFnRef = B.createFunctionRef(loc, sendFn);
Expand Down Expand Up @@ -4241,12 +4286,18 @@ void TFPartition::run() {
// If this function is a building block of larger tensor programs (e.g.
// the ops defined in the TensorFlow module), then don't transform it in
// isolation.
DEBUG(llvm::dbgs() << "Processing SIL function " << hostFn->getName()
<< " in TFPartition::run().\n");
if (!tfc.shouldBePartitioned(hostFn))
return;

DEBUG(llvm::dbgs() << " " << hostFn->getName()
<< " should be partitioned.\n");
TFFunctionPartition partitioner(*hostFn, PM, *tfModule);
if (!partitioner.markFunction())
return; // No tensor ops found in the function.
DEBUG(llvm::dbgs() << " " << hostFn->getName()
<< " contains tensor op(s).\n");

// Check to see if we cannot transform the function but should. In this
// case we emit a compiler error. This is a limitation of the compiler that
Expand Down
11 changes: 11 additions & 0 deletions test/TensorFlow/control_flow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,14 @@ public func foo<T>(_ a: T) {
_hostOp(b)
}
*/

// TODO: enable this for-loop test once we resolve
// https://bugs.swift.org/browse/SR-7765:
/// SESE FIXME: Imperfect loop exits not handled yet!
// public func foo(n: Int32) {
// var a = Tensor<Float>(1.0)
// for _ in 0..<n {
// a += a
// }
// _hostOp(a)
// }

0 comments on commit 3d37ae7

Please sign in to comment.