Skip to content

Commit

Permalink
[Handshake] Support for extra channel signals
Browse files Browse the repository at this point in the history
The `handshake::ChannelType` type can now contain an arbitrary number of
named downstream/upstream extra signals, each with a potentially
different index, integer, or float type. This change relies on

- a custom `TypeParameter` in TableGen,
- explicit parsing/printing/verification logic for
  `handshake::ChannelType`, and
- custom data-structures and allocation code for extra signals.

New Handshake unit tests check the correctness of the
parsing/verification logic for extra channel signals.
  • Loading branch information
lucas-rami committed Jul 8, 2024
1 parent 3676223 commit bebfb02
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 8 deletions.
50 changes: 49 additions & 1 deletion include/dynamatic/Dialect/Handshake/HandshakeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,61 @@
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces of the handshake dialect.
// Declares backing data-structures and API to express and manipulate custom
// Handshake types.
//
//===----------------------------------------------------------------------===//

#ifndef DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_TYPES_H
#define DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_TYPES_H

#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"

namespace dynamatic {
namespace handshake {

/// A dataflow channel's extra signal. The signal has a unique (within a
/// channel's context) name, specific MLIR type, and a direction (downstream or
/// upstream).
struct ExtraSignal {

/// Used when creating `handshake::ChannelType` instances. Owns its name
/// instead of referencing it.
struct Storage {
std::string name;
mlir::Type type = nullptr;
bool downstream = true;

Storage() = default;
Storage(llvm::StringRef name, mlir::Type type, bool downstream = true);
};

/// The signal's name.
llvm::StringRef name;
/// The signal's MLIR type.
mlir::Type type;
/// Whether the signal is going downstream or upstream.
bool downstream;

/// Simple member-by-member constructor.
ExtraSignal(llvm::StringRef name, mlir::Type type, bool downstream = true);

/// Constructs from the storage type (should not be used by client code).
ExtraSignal(const Storage &storage);
};

bool operator==(const ExtraSignal &lhs, const ExtraSignal &rhs);
inline bool operator!=(const ExtraSignal &lhs, const ExtraSignal &rhs) {
return !(lhs == rhs);
}

// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const ExtraSignal &signal);

} // namespace handshake
} // namespace dynamatic

#define GET_TYPEDEF_CLASSES
#include "dynamatic/Dialect/Handshake/HandshakeTypes.h.inc"

Expand Down
37 changes: 30 additions & 7 deletions include/dynamatic/Dialect/Handshake/HandshakeTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_TYPES_TD
#define DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_TYPES_TD

include "mlir/IR/AttrTypeBase.td"

/// Base class for types in the Handshake dialect.
class Handshake_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Handshake_Dialect, name, traits> {
Expand All @@ -30,30 +32,51 @@ def ControlType : Handshake_Type<"Control", "control"> {
}];

let parameters = (ins);
let assemblyFormat = "";
}

def ExtraSignals : TypeParameter<
"::llvm::ArrayRef<::dynamatic::handshake::ExtraSignal>",
"An optional array of extra signals for a dataflow channel"> {
let allocator = [{
::llvm::SmallVector<::dynamatic::handshake::ExtraSignal> tmpSignals;
for (const ::dynamatic::handshake::ExtraSignal &signal : $_self) {
::dynamatic::handshake::ExtraSignal& tmp = tmpSignals.emplace_back(signal);
tmp.name = $_allocator.copyInto(tmp.name);
}
}] # "$_dst = $_allocator.copyInto(" # cppType # [{ (tmpSignals));
}];
let cppStorageType = "::llvm::SmallVector<::dynamatic::handshake::ExtraSignal::Storage>";
let convertFromStorage = [{convertExtraSignalsFromStorage($_self)}];
let comparator = cppType # "($_lhs) == " # cppType # "($_rhs)";

let defaultValue = cppType # "()";
}

def ChannelType : Handshake_Type<"Channel", "channel"> {
let summary = "A dataflow channel.";
let summary = "A dataflow channel with optional extra signals.";
let description = [{
Represents a dataflow channel, which is made up of
- a data signal of arbitrary width and type going downstream (in the same
direction as the natural SSA def-use relation's direction),
- a 1-bit valid signal going downstream (in the same direction as the
natural SSA def-use relation's direction), and
natural SSA def-use relation's direction),
- a 1-bit ready signal going upsteam (in the opposite direction as the
natural SSA def-use relation's direction).
natural SSA def-use relation's direction), and
- an optional list of named extra signals of arbitrary width and type which
may go downstream or upstream.
}];

let parameters = (ins "::mlir::Type":$dataType);
let parameters = (ins "::mlir::Type":$dataType, ExtraSignals:$extraSignals);

let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$dataType), [{
return Base::get(dataType.getContext(), dataType);
return ChannelType::get(dataType.getContext(), dataType, {});
}]>
];

let assemblyFormat = "`<` $dataType `>`";
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;

}

#endif // DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_TYPES_TD
1 change: 1 addition & 0 deletions lib/Dialect/Handshake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_dynamatic_dialect_library(DynamaticHandshake
HandshakeOps.cpp
HandshakeDialect.cpp
HandshakeInterfaces.cpp
HandshakeTypes.cpp
MemoryInterfaces.cpp

LINK_LIBS PUBLIC
Expand Down
194 changes: 194 additions & 0 deletions lib/Dialect/Handshake/HandshakeTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//===- HandshakeTypes.cpp - Handshake types ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements the API to express and manipulate custom Handshake types.
//
//===----------------------------------------------------------------------===//

#include "dynamatic/Dialect/Handshake/HandshakeTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <set>

using namespace mlir;
using namespace dynamatic;
using namespace dynamatic::handshake;

//===----------------------------------------------------------------------===//
// ChannelType
//===----------------------------------------------------------------------===//

constexpr llvm::StringLiteral UPSTREAM_SYMBOL("U");

static inline bool isSupportedSignalType(Type type) {
return isa<IndexType, IntegerType, FloatType>(type);
}

Type ChannelType::parse(AsmParser &odsParser) {
FailureOr<Type> dataType;
FailureOr<SmallVector<ExtraSignal::Storage>> extraSignalsStorage;

// Parse literal '<'
if (odsParser.parseLess())
return {};

// Parse variable 'dataType'
dataType = FieldParser<Type>::parse(odsParser);
if (failed(dataType)) {
odsParser.emitError(odsParser.getCurrentLocation(),
"failed to parse ChannelType parameter 'dataType' "
"which is to be a `Type`");
return nullptr;
}
if (!isSupportedSignalType(*dataType)) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"failed to parse ChannelType parameter 'dataType' "
"which must be `IndexType`, `IntegerType`, or `FloatType`");
return nullptr;
}

// Parse variable 'extraSignals'
std::set<std::string> extraNames;
extraSignalsStorage = [&]() -> FailureOr<SmallVector<ExtraSignal::Storage>> {
SmallVector<ExtraSignal::Storage> storage;

if (!odsParser.parseOptionalComma()) {
auto parseSignal = [&]() -> ParseResult {
auto &signal = storage.emplace_back();

// Parse name and check for uniqueness
if (odsParser.parseKeywordOrString(&signal.name))
return failure();
if (auto [_, newName] = extraNames.insert(signal.name); !newName) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"duplicated extra signal name, signal names must be unique");
return failure();
}

// Parse colon and type and check type legality
if (odsParser.parseColon() || odsParser.parseType(signal.type))
return failure();
if (!isSupportedSignalType(signal.type)) {
odsParser.emitError(odsParser.getCurrentLocation(),
"failed to parse extra signal type which must be "
"`IndexType`, `IntegerType`, or `FloatType`");
return failure();
}

// Attempt to parse the optional upstream symbol
if (!odsParser.parseOptionalLParen()) {
std::string upstreamSymbol;
if (odsParser.parseKeywordOrString(&upstreamSymbol) ||
upstreamSymbol != UPSTREAM_SYMBOL || odsParser.parseRParen())
return failure();
signal.downstream = false;
}
return success();
};

if (odsParser.parseLSquare() ||
odsParser.parseCommaSeparatedList(parseSignal) ||
odsParser.parseRSquare())
return failure();
}

return storage;
}();
if (failed(extraSignalsStorage)) {
odsParser.emitError(odsParser.getCurrentLocation(),
"failed to parse ChannelType parameter 'extraSignals' "
"which is to be a `ArrayRef<ExtraSignal>`");
return nullptr;
}

// Parse literal '>'
if (odsParser.parseGreater())
return {};

// Convert the element type of the extra signal storage list to its
// non-storage version (these will be uniqued/allocated by ChannelType::get)
SmallVector<ExtraSignal> extraSignals;
for (const ExtraSignal::Storage &signalStorage : *extraSignalsStorage)
extraSignals.emplace_back(signalStorage);

return ChannelType::get(odsParser.getContext(), *dataType, extraSignals);
}

void ChannelType::print(AsmPrinter &odsPrinter) const {
odsPrinter << "<";
odsPrinter.printStrippedAttrOrType(getDataType());
if (!getExtraSignals().empty()) {
auto printSignal = [&](const ::dynamatic::handshake::ExtraSignal &signal) {
odsPrinter << signal.name << ": " << signal.type;
if (!signal.downstream)
odsPrinter << "(" << UPSTREAM_SYMBOL << ")";
};

// Print all signals enclosed in square brackets
odsPrinter << ", [";
for (const ::dynamatic::handshake::ExtraSignal &signal :
getExtraSignals().drop_back()) {
printSignal(signal);
odsPrinter << ", ";
}
printSignal(getExtraSignals().back());
odsPrinter << "]";
;
}
odsPrinter << ">";
}

LogicalResult ChannelType::verify(function_ref<InFlightDiagnostic()> emitError,
Type dataType,
ArrayRef<ExtraSignal> extraSignals) {
if (!isSupportedSignalType(dataType)) {
return emitError() << "expected data type to be `IndexType`, "
"`IntegerType`, or `FloatType`, but got "
<< dataType;
}

DenseSet<StringRef> names;
for (const ExtraSignal &signal : extraSignals) {
if (auto [_, newName] = names.insert(signal.name); !newName) {
return emitError() << "expected all signal names to be unique but '"
<< signal.name << "' appears more than once";
}
if (!isSupportedSignalType(signal.type)) {
return emitError() << "expected extra signal type to be `IndexType`, "
"`IntegerType`, or `FloatType`, but "
<< signal.name << "' has type " << signal.type;
}
}
return success();
}

ExtraSignal::Storage::Storage(StringRef name, mlir::Type type, bool downstream)
: name(name), type(type), downstream(downstream) {}

ExtraSignal::ExtraSignal(StringRef name, mlir::Type type, bool downstream)
: name(name), type(type), downstream(downstream) {}

ExtraSignal::ExtraSignal(const ExtraSignal::Storage &storage)
: name(storage.name), type(storage.type), downstream(storage.downstream) {}

bool dynamatic::handshake::operator==(const ExtraSignal &lhs,
const ExtraSignal &rhs) {
return lhs.name == rhs.name && lhs.type == rhs.type &&
lhs.downstream == rhs.downstream;
}

llvm::hash_code dynamatic::handshake::hash_value(const ExtraSignal &signal) {
return llvm::hash_combine(signal.name, signal.type, signal.downstream);
}

#include "dynamatic/Dialect/Handshake/HandshakeTypes.cpp.inc"
16 changes: 16 additions & 0 deletions test/Dialect/Handshake/errors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: dynamatic-opt %s --split-input-file --verify-diagnostics

// expected-error @below {{failed to parse ChannelType parameter 'dataType' which must be `IndexType`, `IntegerType`, or `FloatType`}}
handshake.func private @invalidDataType(%arg0: !handshake.channel<!handshake.control>) -> !handshake.control;

// -----

// expected-error @below {{failed to parse extra signal type which must be `IndexType`, `IntegerType`, or `FloatType`}}
// expected-error @below {{failed to parse ChannelType parameter 'extraSignals' which is to be a `ArrayRef<ExtraSignal>`}}
handshake.func private @invalidExtraType(%arg0: !handshake.channel<i32, [extra: !handshake.control]>) -> !handshake.control;

// -----

// expected-error @below {{duplicated extra signal name, signal names must be unique}}
// expected-error @below {{failed to parse ChannelType parameter 'extraSignals' which is to be a `ArrayRef<ExtraSignal>`}}
handshake.func private @duplicateExtraNames(%arg0: !handshake.channel<i32, [extra: i16, extra: f16]>) -> !handshake.control;
23 changes: 23 additions & 0 deletions test/Dialect/Handshake/types.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: dynamatic-opt %s --split-input-file --verify-diagnostics

handshake.func private @simpleControl(%arg0: !handshake.control) -> !handshake.control

// -----

handshake.func @simpleChannel(%arg0: !handshake.channel<i32>) -> !handshake.control

// -----

handshake.func @simpleChannelWithDownExtra(%arg0: !handshake.channel<i32, [extra: i1]>) -> !handshake.control

// -----

handshake.func @simpleChannelWithUpExtra(%arg0: !handshake.channel<i32, [extra: i1 (U)]>) -> !handshake.control

// -----

handshake.func @simpleChannelWithDownAndUpExtra(%arg0: !handshake.channel<i32, [extraDown: i1, extraUp: i1 (U)]>) -> !handshake.control

// -----

handshake.func @validDataAndExtraTypes(%arg0: !handshake.channel<f32, [idx: index]>) -> !handshake.control

0 comments on commit bebfb02

Please sign in to comment.