Skip to content

Commit

Permalink
[HandshakeReshapeChannels] Simplify channels before emission
Browse files Browse the repository at this point in the history
Introduces a new pass meant to be a last Handshake-level pre-processing
step before lowering to HW and eventually to RTL. The pass simplifies
the form of channel-typed values around operations that do not read the
content of extra signals (and potentially the data signal). This relies
on

- eligible operations implementing the `ReshapableChannelsInterface`
  interface to signify to the pass which of their operand channels may
  be reshaped and how, and
- the insertion of `handshake::ReshapeOp` in the IR to go back-and-forth
  between complex and simple channel forms.

New unit tests are added for the pass.
  • Loading branch information
lucas-rami committed Aug 4, 2024
1 parent 041480e commit 2082b20
Show file tree
Hide file tree
Showing 14 changed files with 477 additions and 43 deletions.
12 changes: 9 additions & 3 deletions include/dynamatic/Dialect/Handshake/HandshakeArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Handshake_Arith_Op<string mnemonic, list<Trait> traits = []> :

class Handshake_Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
Handshake_Arith_Op<mnemonic, traits # [
SameOperandsAndResultType,
SameOperandsAndResultType, ReshapableChannelsInterface,
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName", "getResultName"]>,
]> {
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
Expand Down Expand Up @@ -55,7 +55,9 @@ class Handshake_Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
}

class Handshake_Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
Handshake_Arith_Op<mnemonic, traits # [SameOperandsAndResultType]> {
Handshake_Arith_Op<mnemonic, traits # [SameOperandsAndResultType,
ReshapableChannelsInterface]
> {
let arguments = (ins FloatChannelType:$operand);
let results = (outs FloatChannelType:$result);

Expand All @@ -65,6 +67,7 @@ class Handshake_Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
class Handshake_Arith_CompareOp<string mnemonic, list<Trait> traits = []> :
Handshake_Arith_Op<mnemonic, traits # [
AllTypesMatch<["lhs", "rhs"]>, SameExtraSignalsInterface,
ReshapableChannelsInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>,
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName", "getResultName"]>
]> {
Expand All @@ -86,7 +89,9 @@ class Handshake_Arith_CompareOp<string mnemonic, list<Trait> traits = []> :
}

class Handshake_Arith_IToICastOp<string mnemonic, list<Trait> traits = []> :
Handshake_Arith_Op<mnemonic, traits # [SameExtraSignalsInterface]> {
Handshake_Arith_Op<mnemonic, traits # [SameExtraSignalsInterface,
ReshapableChannelsInterface]
> {
let arguments = (ins IntChannelType:$in);
let results = (outs IntChannelType:$out);

Expand Down Expand Up @@ -242,6 +247,7 @@ def Handshake_OrIOp : Handshake_Arith_IntBinaryOp<"ori", [Commutative]> {

def Handshake_SelectOp : Handshake_Arith_Op<"select", [
AllTypesMatch<["trueValue", "falseValue", "result"]>,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>,
DeclareOpInterfaceMethods<SameExtraSignalsInterface, ["getChannelsWithSameExtraSignals"]>,
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName", "getResultName"]>,
]> {
Expand Down
11 changes: 8 additions & 3 deletions include/dynamatic/Dialect/Handshake/HandshakeInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_INTERFACES_H

#include "dynamatic/Dialect/Handshake/HandshakeDialect.h"
#include "dynamatic/Dialect/Handshake/HandshakeTypes.h"
#include "dynamatic/Support/LLVM.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
Expand All @@ -32,10 +33,8 @@
namespace dynamatic {
namespace handshake {

class ChannelType;

namespace detail {
/// `SameExtraSignalsInterface` default `getChannelsWithSameExtraSignals`'s
/// `SameExtraSignalsInterface`'s default `getChannelsWithSameExtraSignals`'s
/// function (defined as a free function to avoid instantiating an
/// implementation for every concrete operation type).
SmallVector<mlir::TypedValue<handshake::ChannelType>>
Expand All @@ -46,6 +45,12 @@ getChannelsWithSameExtraSignals(Operation *op);
/// operation type).
LogicalResult verifySameExtraSignalsInterface(
Operation *op, ArrayRef<mlir::TypedValue<ChannelType>> channels);

/// `ReshapableChannelsInterface`'s default `getReshapableChannelType` method
/// implementation (defined as a free function to avoid instantiating an
/// implementation for every concrete operation type).
std::pair<handshake::ChannelType, bool> getReshapableChannelType(Operation *op);

} // namespace detail

class ControlType;
Expand Down
27 changes: 26 additions & 1 deletion include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def SameExtraSignalsInterface : OpInterface<"SameExtraSignalsInterface"> {
let methods = [
InterfaceMethod<[{
Returns the list of operand/result channels that should have the same
extra signals.
extra signals. By default these are all channel-typed operands/results.
}], "::mlir::SmallVector<::mlir::TypedValue<::dynamatic::handshake::ChannelType>>",
"getChannelsWithSameExtraSignals", (ins), "", [{
return ::dynamatic::handshake::detail::getChannelsWithSameExtraSignals(
Expand All @@ -289,4 +289,29 @@ def SameExtraSignalsInterface : OpInterface<"SameExtraSignalsInterface"> {
}];
}

def ReshapableChannelsInterface : OpInterface<"ReshapableChannelsInterface"> {
let cppNamespace = "::dynamatic::handshake";
let description = [{
Handshake operations which do not care for the content of all extra signals
(and optionally, the data signal) of specific operand channels
(`handshake::ChannelType`) may implement this interface to let the backend
know that it is free to reshape these channel-typed operands. Reshaping
tends to simplify RTL generation.
}];

let methods = [
InterfaceMethod<[{
Returns the operand channel type which can be reshaped, as well as a
boolean indicating whether the channel type's data signal is also
ignored by the operation. The returned type may be `nullptr` if no such
channel type exists. By default, returns the type of the first operand
if it is channel-typed and `false`; otherwise returns `nullptr` and an
unspecified boolean.
}], "::std::pair<::dynamatic::handshake::ChannelType, bool>",
"getReshapableChannelType", (ins), "", [{
return ::dynamatic::handshake::detail::getReshapableChannelType($_op);
}]>
];
}

#endif //DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_INTERFACES
59 changes: 42 additions & 17 deletions include/dynamatic/Dialect/Handshake/HandshakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def InstanceOp : Handshake_Op<"instance", [
}

class Handshake_BufferOp<string mnemonic> : Handshake_Op<mnemonic, [
BufferOpInterface, HasClock, SameOperandsAndResultType
BufferOpInterface, HasClock, SameOperandsAndResultType,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>,
]> {
let arguments = (ins HandshakeType:$operand,
ConfinedAttr<I32Attr, [IntMinValue<1>]>:$slots);
Expand All @@ -241,6 +242,12 @@ class Handshake_BufferOp<string mnemonic> : Handshake_Op<mnemonic, [
mlir::OperationState &result) {
return parseBufferOp(parser, result);
}

std::pair<handshake::ChannelType, bool>
$cppClass::getReshapableChannelType() {
return {mlir::dyn_cast<handshake::ChannelType>(getOperand().getType()),
true};
}
}];

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -273,8 +280,10 @@ def TEHBOp : Handshake_BufferOp<"tehb"> {
}

class Handshake_ForkOp<string mnemonic, list<Trait> traits = []> :
Handshake_Op<mnemonic, traits # [Pure, HasClock, SameOperandsAndResultType]>
{
Handshake_Op<mnemonic, traits # [
Pure, HasClock, SameOperandsAndResultType,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>
]> {
let arguments = (ins HandshakeType:$operand);
let results = (outs Variadic<HandshakeType>:$result);

Expand All @@ -290,6 +299,14 @@ class Handshake_ForkOp<string mnemonic, list<Trait> traits = []> :
custom<SingleTypedHandshakeOp>($operand, attr-dict,
type($operand), type($result))
}];

let extraClassDefinition = [{
std::pair<handshake::ChannelType, bool>
$cppClass::getReshapableChannelType() {
return {mlir::dyn_cast<handshake::ChannelType>(getOperand().getType()),
true};
}
}];
}

def ForkOp : Handshake_ForkOp<"fork"> {
Expand Down Expand Up @@ -322,7 +339,8 @@ def LazyForkOp : Handshake_ForkOp<"lazy_fork"> {

def MergeOp : Handshake_Op<"merge", [
Pure, SameOperandsAndResultType,
DeclareOpInterfaceMethods<MergeLikeOpInterface, ["getDataResult"]>
DeclareOpInterfaceMethods<MergeLikeOpInterface, ["getDataResult"]>,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>
]> {
let summary = "merge operation";
let description = [{
Expand Down Expand Up @@ -350,7 +368,9 @@ def MuxOp : Handshake_Op<"mux", [
DeclareOpInterfaceMethods<MergeLikeOpInterface, ["getDataResult"]>,
DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>,
DeclareOpInterfaceMethods<ControlInterface, ["isControl"]>,
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName"]>
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName"]>,
DeclareOpInterfaceMethods<SameExtraSignalsInterface, ["getChannelsWithSameExtraSignals"]>,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>
]> {
let summary = "mux operation";
let description = [{
Expand Down Expand Up @@ -383,7 +403,9 @@ def MuxOp : Handshake_Op<"mux", [
def ControlMergeOp : Handshake_Op<"control_merge", [
Pure, HasClock,
DeclareOpInterfaceMethods<MergeLikeOpInterface, ["getDataResult"]>,
DeclareOpInterfaceMethods<NamedIOInterface, ["getResultName"]>
DeclareOpInterfaceMethods<NamedIOInterface, ["getResultName"]>,
DeclareOpInterfaceMethods<SameExtraSignalsInterface, ["getChannelsWithSameExtraSignals"]>,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>
]> {
let summary = "control merge operation";
let description = [{
Expand Down Expand Up @@ -419,7 +441,10 @@ def ControlMergeOp : Handshake_Op<"control_merge", [
let hasVerifier = 1;
}

def BranchOp : Handshake_Op<"br", [Pure, SameOperandsAndResultType]> {
def BranchOp : Handshake_Op<"br", [
Pure, SameOperandsAndResultType,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>
]> {
let summary = "branch operation";
let description = [{
The branch operation represents an unconditional
Expand All @@ -442,11 +467,11 @@ def BranchOp : Handshake_Op<"br", [Pure, SameOperandsAndResultType]> {
}

def ConditionalBranchOp : Handshake_Op<"cond_br", [
AllTypesMatch<["dataOperand", "trueResult", "falseResult"]>,
DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>,
DeclareOpInterfaceMethods<ControlInterface, ["isControl"]>,
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName"]>,
DeclareOpInterfaceMethods<NamedIOInterface, ["getResultName"]>,
AllTypesMatch<["dataOperand", "trueResult", "falseResult"]>
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName", "getResultName"]>,
DeclareOpInterfaceMethods<ReshapableChannelsInterface, ["getReshapableChannelType"]>
]> {
let summary = "conditional branch operation";
let description = [{
Expand Down Expand Up @@ -536,7 +561,8 @@ def JoinOp : Handshake_Op<"join", [
let assemblyFormat = "$data attr-dict `:` type($result)";
}

def NotOp : Handshake_Op<"not", [Pure, SameOperandsAndResultType]> {
def NotOp : Handshake_Op<"not", [Pure, SameOperandsAndResultType,
ReshapableChannelsInterface]> {
let summary = "Logical negation";
let description = [{
Bitwise logical negation.
Expand Down Expand Up @@ -1217,8 +1243,7 @@ def BundleOp : Handshake_Op<"bundle", [Pure]> {
}

def UnbundleOp : Handshake_Op<"unbundle", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>
Pure, DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>
]> {
let summary = [{
Unbundles a `handshake::ChannelType` or `handshake::ControlType` into
Expand Down Expand Up @@ -1304,14 +1329,14 @@ def ReshapeOp : Handshake_Op<"reshape", [Pure]> {
bitwidths of all downstream extra signals plus the data signal's bitwidth.
If there are no downstream extra signals, the data type is unchanged. All
upstream extra signals (if any) are merged into a single one named
`COMBINED_UP_NAME` with type `i<Y>`, where Y is the sum of the bitwidths
`MERGED_UP_NAME` with type `i<Y>`, where Y is the sum of the bitwidths
of all upstream extra signals. `ChannelReshapeType::SplitData` performs
the reverse transformation.

- `ChannelReshapeType::MergeExtra` behaves identically to
`ChannelReshapeType::MergeData` for extra upstream signals. However,
all downstream extra signals (if any) are merged into a single one named
`COMBINED_DOWN_NAME` with type `i<Z>`, where Z is the sum of the bitwidths
`MERGED_DOWN_NAME` with type `i<Z>`, where Z is the sum of the bitwidths
of all downstream extra signals. The data type always remain unchanged.
`ChannelReshapeType::SplitExtra` performs the reverse transformation.

Expand Down Expand Up @@ -1355,8 +1380,8 @@ def ReshapeOp : Handshake_Op<"reshape", [Pure]> {

let extraClassDeclaration = [{
static constexpr ::llvm::StringLiteral
COMBINED_DOWN_NAME = ::llvm::StringLiteral("mergedDown"),
COMBINED_UP_NAME = ::llvm::StringLiteral("mergedUp");
MERGED_DOWN_NAME = ::llvm::StringLiteral("mergedDown"),
MERGED_UP_NAME = ::llvm::StringLiteral("mergedUp");
}];
}

Expand Down
28 changes: 28 additions & 0 deletions include/dynamatic/Transforms/HandshakeReshapeChannels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- HandshakeReshapeChannels.h - Reshape channels' signals ---*- C++ -*-===//
//
// Dynamatic is under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the --handshake-reshape-channels pass.
//
//===----------------------------------------------------------------------===//

#ifndef DYNAMATIC_TRANSFORMS_HANDSHAKERESHAPE_CHANNELS_H
#define DYNAMATIC_TRANSFORMS_HANDSHAKERESHAPE_CHANNELS_H

#include "dynamatic/Support/DynamaticPass.h"

namespace dynamatic {

#define GEN_PASS_DECL_HANDSHAKERESHAPECHANNELS
#define GEN_PASS_DEF_HANDSHAKERESHAPECHANNELS
#include "dynamatic/Transforms/Passes.h.inc"

std::unique_ptr<dynamatic::DynamaticPass> createHandshakeReshapeChannels();

} // namespace dynamatic

#endif // DYNAMATIC_TRANSFORMS_HANDSHAKERESHAPE_CHANNELS_H
1 change: 1 addition & 0 deletions include/dynamatic/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "dynamatic/Transforms/HandshakeMinimizeLSQUsage.h"
#include "dynamatic/Transforms/HandshakeOptimizeBitwidths.h"
#include "dynamatic/Transforms/HandshakePrepareForLegacy.h"
#include "dynamatic/Transforms/HandshakeReshapeChannels.h"
#include "dynamatic/Transforms/MarkMemoryDependencies.h"
#include "dynamatic/Transforms/MarkMemoryInterfaces.h"
#include "dynamatic/Transforms/OperationNames.h"
Expand Down
12 changes: 12 additions & 0 deletions include/dynamatic/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ def HandshakeMinimizeCstWidth : DynamaticPass<"handshake-minimize-cst-width",
"If true, allows bitwidth optimization of negative values.">];
}

def HandshakeReshapeChannels : DynamaticPass<"handshake-reshape-channels"> {
let summary = "Reshape channels to simplify RTL generation.";
let description = [{
Relies on the `ReshapableChannelsInterface` (part of the Handshake dialect)
to simplify the shape of channels with extra signals around operations that
ignore the content of all extra signals (and potentially the data signal
too). The pass inserts `handshake::ReshapeOp` operations as needed to
reshape channels back and forth between their original and simplified forms.
}];
let constructor = "dynamatic::createHandshakeReshapeChannels()";
}

def HandshakeSetBufferingProperties :
DynamaticPass<"handshake-set-buffering-properties"> {
let summary = "Attach buffering properties to specifc channels in the IR";
Expand Down
Loading

0 comments on commit 2082b20

Please sign in to comment.