Skip to content

Commit

Permalink
[Handshake] Add handshake::ReshapeOp operation
Browse files Browse the repository at this point in the history
This introduces a new reshaping operation to Handshake which can change
the nature of a channel type's signals without altering the actual
payload carried by the channel, essentially performing an internal
rewiring of the various signals making up the channel. This is mostly
meant to be used during a pre-processing pass for the backend to
simplify channels around components that do not know how to handle or
care for extra signals.
  • Loading branch information
lucas-rami committed Jul 11, 2024
1 parent 9fef8a5 commit 6db1062
Show file tree
Hide file tree
Showing 7 changed files with 423 additions and 6 deletions.
81 changes: 79 additions & 2 deletions include/dynamatic/Dialect/Handshake/HandshakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,8 @@ SameOperandsAndResultType

def BundleOp : Handshake_Op<"bundle", [Pure]> {
let summary = [{
Bundles individual signals into a `handshake::ChannelType` or `handshake::ControlType`.
Bundles individual signals into a `handshake::ChannelType` or
`handshake::ControlType`.
}];
let description = [{
Combines individual signals into a channel-like value, producing upstream
Expand Down Expand Up @@ -1216,7 +1217,8 @@ def UnbundleOp : Handshake_Op<"unbundle", [
DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>
]> {
let summary = [{
Unbundles a `handshake::ChannelType` or `handshake::ControlType` into individual signals.
Unbundles a `handshake::ChannelType` or `handshake::ControlType` into
individual signals.
}];
let description = [{
Splits a channel-like value into its individual signals, producing
Expand Down Expand Up @@ -1274,4 +1276,79 @@ def UnbundleOp : Handshake_Op<"unbundle", [
let hasVerifier = 1;
}

def ChannelReshapeTypeAttr : I32EnumAttr<
"ChannelReshapeType", "",
[
I32EnumAttrCase<"MergeData", 0>,
I32EnumAttrCase<"SplitData", 1>,
I32EnumAttrCase<"MergeExtra", 2>,
I32EnumAttrCase<"SplitExtra", 3>
]> {
let cppNamespace = "::dynamatic::handshake";
}

def ReshapeOp : Handshake_Op<"reshape", [Pure]> {
let summary = "Reshapes the individual signals of a `handshake::ChannelType`.";
let description = [{
Reshapes a `ChannelType`'d value to a different form without modifying the
channel's actual payload, only rerouting bits to different extra signals or
to the channel's data signal. There are currently four reshaping types that
match two-by-two to perform reverse reshapings (each reshaping type pair has
a dedicated operation builder).

- `ChannelReshapeType::MergeData` merges all downstream extra signals into
the data signal, which becomes an `i<X>` where X is the sum of the
bitwidths of all downstream extra signals plus the data signal's bitwidth.
If there are no downstream extra signals, the data type is unchanged. All
upstream extra signals (if any) are merged into a single one named
`COMBINED_UP_NAME` with type `i<Y>`, where Y is the sum of the bitwidths
of all upstream extra signals. `ChannelReshapeType::SplitData` performs
the reverse transformation.

- `ChannelReshapeType::MergeExtra` behaves identically to
`ChannelReshapeType::MergeData` for extra upstream signals. However,
all downstream extra signals (if any) are merged into a single one named
`COMBINED_DOWN_NAME` with type `i<Z>`, where Z is the sum of the bitwidths
of all downstream extra signals. The data type always remain unchanged.
`ChannelReshapeType::SplitExtra` performs the reverse transformation.

Example:
```
// Merging into data
%reshaped = reshape [MergeData] %channel :
(!handshake.channel<f32, [down1: i2, up1: i4 (U), up2: i4 (U), down2: i8]>)
-> (!handshake.channel<i42, [mergedUp: i8 (U)]>)

// -----

// Merging into extra
%reshaped = reshape [MergeExtra] %channel :
(!handshake.channel<f32, [down1: i2, up1: i4 (U), up2: i4 (U), down2: i8]>)
-> (!handshake.channel<f32, [mergedDowm: i10, mergedUp: i8 (U)]>)
```
}];

let arguments = (ins ChannelReshapeTypeAttr:$reshapeType, ChannelType:$channel);
let results = (outs ChannelType:$reshaped);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
"::mlir::TypedValue<::dynamatic::handshake::ChannelType>":$channel,
"bool":$mergeDownstreamIntoData)>,
OpBuilder<(ins
"::mlir::TypedValue<::dynamatic::handshake::ChannelType>":$channel,
"bool":$splitDownstreamFromData, "::mlir::Type":$reshapedType)>,
];

let assemblyFormat = "`[` $reshapeType `]` $channel attr-dict `:` functional-type($channel, $reshaped)";
let hasVerifier = 1;

let extraClassDeclaration = [{
static constexpr ::llvm::StringLiteral
COMBINED_DOWN_NAME = ::llvm::StringLiteral("mergedDown"),
COMBINED_UP_NAME = ::llvm::StringLiteral("mergedUp");
}];
}

#endif // DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_OPS_TD
3 changes: 3 additions & 0 deletions include/dynamatic/Dialect/Handshake/HandshakeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ struct ExtraSignal {

/// Constructs from the storage type (should not be used by client code).
ExtraSignal(const Storage &storage);

/// Returns the signal type's bitwidth.
unsigned getBitWidth() const;
};

bool operator==(const ExtraSignal &lhs, const ExtraSignal &rhs);
Expand Down
5 changes: 4 additions & 1 deletion include/dynamatic/Dialect/Handshake/HandshakeTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,14 @@ def ChannelType : Handshake_Type<"Channel", "channel"> {
unsigned getNumUpstreamExtraSignals() const {
return getNumExtraSignals() - getNumDownstreamExtraSignals();
}

/// Returns the data type's bitwidth.
unsigned getDataBitWidth() const;

/// Determines whether a type is supported as the data type or as the type
/// of an extra signal.
static bool isSupportedSignalType(::mlir::Type type) {
return ::mlir::isa<::mlir::IndexType, ::mlir::IntegerType, ::mlir::FloatType>(type);
return type.isIntOrIndexOrFloat();
}
}];
}
Expand Down
226 changes: 226 additions & 0 deletions lib/Dialect/Handshake/HandshakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
//===----------------------------------------------------------------------===//

#include "dynamatic/Dialect/Handshake/HandshakeOps.h"
#include "dynamatic/Dialect/Handshake/HandshakeDialect.h"
#include "dynamatic/Dialect/Handshake/HandshakeTypes.h"
#include "dynamatic/Support/CFG.h"
#include "dynamatic/Support/LLVM.h"
Expand Down Expand Up @@ -1936,5 +1937,230 @@ LogicalResult UnbundleOp::verify() {
false, [&]() { return emitError(); });
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

/// Returns the total number of bits needed to carry all extra downstream
/// signals and all extra upstream signals.
static std::pair<unsigned, unsigned> getTotalExtraWidths(ChannelType type) {
unsigned totalDownWidth = 0, totalUpWidth = 0;
for (const ExtraSignal &extra : type.getExtraSignals()) {
if (extra.downstream)
totalDownWidth += extra.getBitWidth();
else
totalUpWidth += extra.getBitWidth();
}
return {totalDownWidth, totalUpWidth};
}

void ReshapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
TypedValue<ChannelType> channel,
bool mergeDownstreamIntoData) {
// Set the operands/attributes
MLIRContext *ctx = odsBuilder.getContext();
ChannelReshapeType reshapeType = mergeDownstreamIntoData
? ChannelReshapeType::MergeData
: ChannelReshapeType::MergeExtra;
odsState.addAttribute("reshapeType",
ChannelReshapeTypeAttr::get(ctx, reshapeType));
odsState.addOperands(channel);

// All extra signals will be combined into at most one downstream combined
// signal and one upstream combined signal, compute their bitwidth
auto [totalDownWidth, totalUpWidth] = getTotalExtraWidths(channel.getType());

// The result channel type can be determined from the input channel type and
// reshaping type
SmallVector<ExtraSignal> extraSignals;
Type dataType = channel.getType().getDataType();
;
if (mergeDownstreamIntoData) {
if (totalDownWidth) {
unsigned dataWidth = channel.getType().getDataBitWidth();
dataType = odsBuilder.getIntegerType(totalDownWidth + dataWidth);
}
} else {
// At most a single downstream extra signal should remain
if (totalDownWidth != 0) {
extraSignals.emplace_back(COMBINED_DOWN_NAME,
odsBuilder.getIntegerType(totalUpWidth), false);
}
}
// At most a single upstream extra signal should remain
if (totalUpWidth != 0) {
extraSignals.emplace_back(COMBINED_UP_NAME,
odsBuilder.getIntegerType(totalUpWidth), true);
}

odsState.addTypes(ChannelType::get(ctx, dataType, extraSignals));
}

void ReshapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
TypedValue<ChannelType> channel,
bool splitDownstreamFromData, Type reshapedType) {
// Set the operands/attributes
MLIRContext *ctx = odsBuilder.getContext();
ChannelReshapeType reshapeType = splitDownstreamFromData
? ChannelReshapeType::SplitData
: ChannelReshapeType::SplitExtra;
odsState.addAttribute("reshapeType",
ChannelReshapeTypeAttr::get(ctx, reshapeType));
odsState.addOperands(channel);

// The compatibility of this type with the input will be checked in the
// operation's verification function
odsState.addTypes(reshapedType);
}

LogicalResult ReshapeOp::verify() {
ChannelType srcType = getChannel().getType(),
dstType = getReshaped().getType();
std::pair<unsigned, unsigned> srcExtraWidths = getTotalExtraWidths(srcType),
dstExtraWidths = getTotalExtraWidths(dstType);

// Fails if the merged type has at least one extra downstream signal.
auto checkNoExtraDownstreamSignal = [&](bool isMerge) -> LogicalResult {
ChannelType type = isMerge ? dstType : srcType;
unsigned numSignals = type.getNumDownstreamExtraSignals();
if (numSignals == 0)
return success();
return emitError() << "too many extra downstream signals in the "
<< (isMerge ? "destination" : "source")
<< " type, expected 0 but got " << numSignals;
};

// Fails if the merged type has more than one extra signal of the specified
// direction, or if the single extra signal of the direction is incompatible
// with the split type.
auto checkMergedExtraSignal = [&](bool isMerge,
bool isDownstream) -> LogicalResult {
ChannelType type = isMerge ? dstType : srcType;
unsigned expectedWidth, numExtra;
StringRef dirStr, combinedName;
if (isDownstream) {
expectedWidth = isMerge ? srcExtraWidths.first : dstExtraWidths.first;
numExtra = type.getNumDownstreamExtraSignals();
dirStr = "downstream";
combinedName = COMBINED_DOWN_NAME;
} else {
expectedWidth = isMerge ? srcExtraWidths.second : dstExtraWidths.second;
numExtra = type.getNumUpstreamExtraSignals();
dirStr = "uptream";
combinedName = COMBINED_UP_NAME;
}

// There must be either 0 or 1 extra signal of the specified direction
if (numExtra == 0)
return success();
if (numExtra > 1) {
return emitError() << "merged channel type should have at most one "
<< dirStr << " signal, but got " << numExtra;
}

// Retrieve the single extra signal of the correct direction, then check its
// name, bitwidth, and type
const ExtraSignal &extraSignal =
*llvm::find_if(type.getExtraSignals(), [&](const ExtraSignal &extra) {
return extra.downstream == isDownstream;
});

if (extraSignal.name != combinedName) {
return emitError() << "invalid name for merged extra " << dirStr
<< " signal, expected '" << combinedName
<< "' but got '" << extraSignal.name << "'";
}
if (extraSignal.getBitWidth() != expectedWidth) {
return emitError() << "invalid bitwidth for merged extra " << dirStr
<< " signal, expected " << expectedWidth << " but got "
<< extraSignal.getBitWidth();
}
if (!isa<IntegerType>(extraSignal.type)) {
return emitError() << "invalid type for merged extra " << dirStr
<< " signal, expected IntegerType but got "
<< extraSignal.type;
}
return success();
};

// Fails if the merged type's data type is incompatible with the split type.
auto checkCompatibleDataType = [&](bool isMerge) -> LogicalResult {
ChannelType splitType, mergeType;
unsigned extraWidth;
if (isMerge) {
splitType = srcType;
mergeType = dstType;
extraWidth = srcExtraWidths.first;
} else {
splitType = dstType;
mergeType = srcType;
extraWidth = dstExtraWidths.first;
}
Type mergedDataType = mergeType.getDataType();

unsigned expectedMergedDataWidth = extraWidth + splitType.getDataBitWidth();
if (expectedMergedDataWidth != mergeType.getDataBitWidth()) {
return emitError() << "invalid merged data type bitwidth, expected "
<< expectedMergedDataWidth << " but got "
<< mergeType.getDataBitWidth();
}
if (extraWidth) {
if (!isa<IntegerType>(mergedDataType)) {
return emitError() << "invalid merged data type, expected merged "
"IntegerType but got "
<< mergedDataType;
}
} else {
if (splitType.getDataType() != mergedDataType) {
return emitError()
<< "invalid destination data type, expected source data type "
<< splitType.getDataType() << " but got " << mergedDataType;
}
}
return success();
};

// Fails if the merged type's and split type's data types are different.
auto checkSameDataType = [&]() -> LogicalResult {
if (srcType.getDataType() != dstType.getDataType()) {
return emitError() << "reshaping in this mode should not change "
"the data type, expected "
<< srcType.getDataType() << " but got "
<< dstType.getDataType();
}
return success();
};

switch (getReshapeType()) {
case ChannelReshapeType::MergeData:
if (failed(checkNoExtraDownstreamSignal(true)) ||
failed(checkMergedExtraSignal(true, false)) ||
failed(checkCompatibleDataType(true)))
return failure();

break;
case ChannelReshapeType::MergeExtra:
if (failed(checkMergedExtraSignal(true, false)) ||
failed(checkMergedExtraSignal(true, true)) ||
failed(checkSameDataType()))
return failure();
break;
case ChannelReshapeType::SplitData:
if (failed(checkNoExtraDownstreamSignal(false)) ||
failed(checkMergedExtraSignal(false, false)) ||
failed(checkCompatibleDataType(false)))
return failure();
break;
case ChannelReshapeType::SplitExtra:
if (failed(checkMergedExtraSignal(false, false)) ||
failed(checkMergedExtraSignal(false, true)) ||
failed(checkSameDataType()))
return failure();
break;
}

return success();
}

#define GET_OP_CLASSES
#include "dynamatic/Dialect/Handshake/Handshake.cpp.inc"
Loading

0 comments on commit 6db1062

Please sign in to comment.