-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Handshake] Support for extra channel signals
The `handshake::ChannelType` type can now contain an arbitrary number of named downstream/upstream extra signals, each with a potentially different index, integer, or float type. This change relies on - a custom `TypeParameter` in TableGen, - explicit parsing/printing/verification logic for `handshake::ChannelType`, and - custom data-structures and allocation code for extra signals. New Handshake unit tests check the correctness of the parsing/verification logic for extra channel signals.
- Loading branch information
1 parent
3676223
commit bebfb02
Showing
6 changed files
with
313 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
//===- HandshakeTypes.cpp - Handshake types ---------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Implements the API to express and manipulate custom Handshake types. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "dynamatic/Dialect/Handshake/HandshakeTypes.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/DialectImplementation.h" | ||
#include "mlir/IR/OpImplementation.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include "llvm/ADT/StringRef.h" | ||
#include <set> | ||
|
||
using namespace mlir; | ||
using namespace dynamatic; | ||
using namespace dynamatic::handshake; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ChannelType | ||
//===----------------------------------------------------------------------===// | ||
|
||
constexpr llvm::StringLiteral UPSTREAM_SYMBOL("U"); | ||
|
||
static inline bool isSupportedSignalType(Type type) { | ||
return isa<IndexType, IntegerType, FloatType>(type); | ||
} | ||
|
||
Type ChannelType::parse(AsmParser &odsParser) { | ||
FailureOr<Type> dataType; | ||
FailureOr<SmallVector<ExtraSignal::Storage>> extraSignalsStorage; | ||
|
||
// Parse literal '<' | ||
if (odsParser.parseLess()) | ||
return {}; | ||
|
||
// Parse variable 'dataType' | ||
dataType = FieldParser<Type>::parse(odsParser); | ||
if (failed(dataType)) { | ||
odsParser.emitError(odsParser.getCurrentLocation(), | ||
"failed to parse ChannelType parameter 'dataType' " | ||
"which is to be a `Type`"); | ||
return nullptr; | ||
} | ||
if (!isSupportedSignalType(*dataType)) { | ||
odsParser.emitError( | ||
odsParser.getCurrentLocation(), | ||
"failed to parse ChannelType parameter 'dataType' " | ||
"which must be `IndexType`, `IntegerType`, or `FloatType`"); | ||
return nullptr; | ||
} | ||
|
||
// Parse variable 'extraSignals' | ||
std::set<std::string> extraNames; | ||
extraSignalsStorage = [&]() -> FailureOr<SmallVector<ExtraSignal::Storage>> { | ||
SmallVector<ExtraSignal::Storage> storage; | ||
|
||
if (!odsParser.parseOptionalComma()) { | ||
auto parseSignal = [&]() -> ParseResult { | ||
auto &signal = storage.emplace_back(); | ||
|
||
// Parse name and check for uniqueness | ||
if (odsParser.parseKeywordOrString(&signal.name)) | ||
return failure(); | ||
if (auto [_, newName] = extraNames.insert(signal.name); !newName) { | ||
odsParser.emitError( | ||
odsParser.getCurrentLocation(), | ||
"duplicated extra signal name, signal names must be unique"); | ||
return failure(); | ||
} | ||
|
||
// Parse colon and type and check type legality | ||
if (odsParser.parseColon() || odsParser.parseType(signal.type)) | ||
return failure(); | ||
if (!isSupportedSignalType(signal.type)) { | ||
odsParser.emitError(odsParser.getCurrentLocation(), | ||
"failed to parse extra signal type which must be " | ||
"`IndexType`, `IntegerType`, or `FloatType`"); | ||
return failure(); | ||
} | ||
|
||
// Attempt to parse the optional upstream symbol | ||
if (!odsParser.parseOptionalLParen()) { | ||
std::string upstreamSymbol; | ||
if (odsParser.parseKeywordOrString(&upstreamSymbol) || | ||
upstreamSymbol != UPSTREAM_SYMBOL || odsParser.parseRParen()) | ||
return failure(); | ||
signal.downstream = false; | ||
} | ||
return success(); | ||
}; | ||
|
||
if (odsParser.parseLSquare() || | ||
odsParser.parseCommaSeparatedList(parseSignal) || | ||
odsParser.parseRSquare()) | ||
return failure(); | ||
} | ||
|
||
return storage; | ||
}(); | ||
if (failed(extraSignalsStorage)) { | ||
odsParser.emitError(odsParser.getCurrentLocation(), | ||
"failed to parse ChannelType parameter 'extraSignals' " | ||
"which is to be a `ArrayRef<ExtraSignal>`"); | ||
return nullptr; | ||
} | ||
|
||
// Parse literal '>' | ||
if (odsParser.parseGreater()) | ||
return {}; | ||
|
||
// Convert the element type of the extra signal storage list to its | ||
// non-storage version (these will be uniqued/allocated by ChannelType::get) | ||
SmallVector<ExtraSignal> extraSignals; | ||
for (const ExtraSignal::Storage &signalStorage : *extraSignalsStorage) | ||
extraSignals.emplace_back(signalStorage); | ||
|
||
return ChannelType::get(odsParser.getContext(), *dataType, extraSignals); | ||
} | ||
|
||
void ChannelType::print(AsmPrinter &odsPrinter) const { | ||
odsPrinter << "<"; | ||
odsPrinter.printStrippedAttrOrType(getDataType()); | ||
if (!getExtraSignals().empty()) { | ||
auto printSignal = [&](const ::dynamatic::handshake::ExtraSignal &signal) { | ||
odsPrinter << signal.name << ": " << signal.type; | ||
if (!signal.downstream) | ||
odsPrinter << "(" << UPSTREAM_SYMBOL << ")"; | ||
}; | ||
|
||
// Print all signals enclosed in square brackets | ||
odsPrinter << ", ["; | ||
for (const ::dynamatic::handshake::ExtraSignal &signal : | ||
getExtraSignals().drop_back()) { | ||
printSignal(signal); | ||
odsPrinter << ", "; | ||
} | ||
printSignal(getExtraSignals().back()); | ||
odsPrinter << "]"; | ||
; | ||
} | ||
odsPrinter << ">"; | ||
} | ||
|
||
LogicalResult ChannelType::verify(function_ref<InFlightDiagnostic()> emitError, | ||
Type dataType, | ||
ArrayRef<ExtraSignal> extraSignals) { | ||
if (!isSupportedSignalType(dataType)) { | ||
return emitError() << "expected data type to be `IndexType`, " | ||
"`IntegerType`, or `FloatType`, but got " | ||
<< dataType; | ||
} | ||
|
||
DenseSet<StringRef> names; | ||
for (const ExtraSignal &signal : extraSignals) { | ||
if (auto [_, newName] = names.insert(signal.name); !newName) { | ||
return emitError() << "expected all signal names to be unique but '" | ||
<< signal.name << "' appears more than once"; | ||
} | ||
if (!isSupportedSignalType(signal.type)) { | ||
return emitError() << "expected extra signal type to be `IndexType`, " | ||
"`IntegerType`, or `FloatType`, but " | ||
<< signal.name << "' has type " << signal.type; | ||
} | ||
} | ||
return success(); | ||
} | ||
|
||
ExtraSignal::Storage::Storage(StringRef name, mlir::Type type, bool downstream) | ||
: name(name), type(type), downstream(downstream) {} | ||
|
||
ExtraSignal::ExtraSignal(StringRef name, mlir::Type type, bool downstream) | ||
: name(name), type(type), downstream(downstream) {} | ||
|
||
ExtraSignal::ExtraSignal(const ExtraSignal::Storage &storage) | ||
: name(storage.name), type(storage.type), downstream(storage.downstream) {} | ||
|
||
bool dynamatic::handshake::operator==(const ExtraSignal &lhs, | ||
const ExtraSignal &rhs) { | ||
return lhs.name == rhs.name && lhs.type == rhs.type && | ||
lhs.downstream == rhs.downstream; | ||
} | ||
|
||
llvm::hash_code dynamatic::handshake::hash_value(const ExtraSignal &signal) { | ||
return llvm::hash_combine(signal.name, signal.type, signal.downstream); | ||
} | ||
|
||
#include "dynamatic/Dialect/Handshake/HandshakeTypes.cpp.inc" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// RUN: dynamatic-opt %s --split-input-file --verify-diagnostics | ||
|
||
// expected-error @below {{failed to parse ChannelType parameter 'dataType' which must be `IndexType`, `IntegerType`, or `FloatType`}} | ||
handshake.func private @invalidDataType(%arg0: !handshake.channel<!handshake.control>) -> !handshake.control; | ||
|
||
// ----- | ||
|
||
// expected-error @below {{failed to parse extra signal type which must be `IndexType`, `IntegerType`, or `FloatType`}} | ||
// expected-error @below {{failed to parse ChannelType parameter 'extraSignals' which is to be a `ArrayRef<ExtraSignal>`}} | ||
handshake.func private @invalidExtraType(%arg0: !handshake.channel<i32, [extra: !handshake.control]>) -> !handshake.control; | ||
|
||
// ----- | ||
|
||
// expected-error @below {{duplicated extra signal name, signal names must be unique}} | ||
// expected-error @below {{failed to parse ChannelType parameter 'extraSignals' which is to be a `ArrayRef<ExtraSignal>`}} | ||
handshake.func private @duplicateExtraNames(%arg0: !handshake.channel<i32, [extra: i16, extra: f16]>) -> !handshake.control; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// RUN: dynamatic-opt %s --split-input-file --verify-diagnostics | ||
|
||
handshake.func private @simpleControl(%arg0: !handshake.control) -> !handshake.control | ||
|
||
// ----- | ||
|
||
handshake.func @simpleChannel(%arg0: !handshake.channel<i32>) -> !handshake.control | ||
|
||
// ----- | ||
|
||
handshake.func @simpleChannelWithDownExtra(%arg0: !handshake.channel<i32, [extra: i1]>) -> !handshake.control | ||
|
||
// ----- | ||
|
||
handshake.func @simpleChannelWithUpExtra(%arg0: !handshake.channel<i32, [extra: i1 (U)]>) -> !handshake.control | ||
|
||
// ----- | ||
|
||
handshake.func @simpleChannelWithDownAndUpExtra(%arg0: !handshake.channel<i32, [extraDown: i1, extraUp: i1 (U)]>) -> !handshake.control | ||
|
||
// ----- | ||
|
||
handshake.func @validDataAndExtraTypes(%arg0: !handshake.channel<f32, [idx: index]>) -> !handshake.control |