Skip to content

Commit

Permalink
[RTL] Add "channel" RTL parameter type
Browse files Browse the repository at this point in the history
Our new type system allows us to define arbitrarily complex dataflow
channels whose complexity need to be expressable in our backend flow.

This commit adds initial support for this through the new
`RTLChannelType` RTL parameter type ("channel" in JSON config files).
At the moment, this RTL parameter type can be constrained in its data
signal's width, number of extra signals, number of extra downstream
signals, and number of extra upstream signals. More fine-grained
cosntraints can easily be added in the future. A new unit test checks
for correct deserialization of this type.
  • Loading branch information
lucas-rami committed Aug 7, 2024
1 parent e6b823f commit 4e1c39f
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 56 deletions.
31 changes: 17 additions & 14 deletions include/dynamatic/Support/JSON/JSON.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "dynamatic/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/JSON.h"
#include <set>

namespace llvm {
namespace json {
Expand Down Expand Up @@ -85,15 +86,16 @@ class ObjectDeserializer {
/// receiver object to allow chaining. A missing key is considered a mapping
/// failure.
template <typename T>
ObjectDeserializer &map(StringRef key, T &out) {
ObjectDeserializer &map(const llvm::Twine &key, T &out) {
if (!mapValid || !obj)
return *this;
if (const llvm::json::Value *val = obj->get(key)) {
if (auto [_, newKey] = mappedKeys.insert(key); !newKey) {
path.field(key).report(ERR_DUP_KEY);
std::string keyStr = key.str();
if (const llvm::json::Value *val = obj->get(keyStr)) {
if (auto [_, newKey] = mappedKeys.insert(keyStr); !newKey) {
path.field(keyStr).report(ERR_DUP_KEY);
mapValid = false;
} else {
mapValid = fromJSON(*val, out, path.field(key));
mapValid = fromJSON(*val, out, path.field(keyStr));
}
return *this;
}
Expand All @@ -106,7 +108,7 @@ class ObjectDeserializer {
/// receiver object to allow chaining. A missing key is *not* considered a
/// mapping failure.
template <typename T>
ObjectDeserializer &map(StringRef key, std::optional<T> &out) {
ObjectDeserializer &map(const llvm::Twine &key, std::optional<T> &out) {
return mapOptional(key, out);
}

Expand All @@ -115,29 +117,30 @@ class ObjectDeserializer {
/// receiver object to allow chaining. A missing key is *not* considered a
/// mapping failure.
template <typename T>
ObjectDeserializer &mapOptional(StringRef key, T &out) {
ObjectDeserializer &mapOptional(const llvm::Twine &key, T &out) {
if (!mapValid || !obj)
return *this;
if (const llvm::json::Value *val = obj->get(key)) {
if (auto [_, newKey] = mappedKeys.insert(key); !newKey) {
path.field(key).report(ERR_DUP_KEY);
std::string keyStr = key.str();
if (const llvm::json::Value *val = obj->get(keyStr)) {
if (auto [_, newKey] = mappedKeys.insert(keyStr); !newKey) {
path.field(keyStr).report(ERR_DUP_KEY);
mapValid = false;
return *this;
}
mapValid = fromJSON(*val, out, path.field(key));
mapValid = fromJSON(*val, out, path.field(keyStr));
}
return *this;
}

/// If the key exists, invoke the callback with its value and updated path.
/// Returns the receiver object to allow chaining. A missing key is
/// considered a mapping failure.
ObjectDeserializer &map(StringRef key, const MapFn &fn);
ObjectDeserializer &map(const llvm::Twine &key, const MapFn &fn);

/// If the key exists, invoke the callback with its value and updated path.
/// Returns the receiver object to allow chaining. A missing key is *not*
/// considered a mapping failure.
ObjectDeserializer &mapOptional(StringRef key, const MapFn &fn);
ObjectDeserializer &mapOptional(const llvm::Twine &key, const MapFn &fn);

/// Terminates a sequence of mappings. Returns true if all mappings succeeded
/// and if all keys in the object were mapped (excluding those present in the
Expand All @@ -154,7 +157,7 @@ class ObjectDeserializer {
llvm::json::Path path;

/// The set of keys that have already been mapped.
llvm::DenseSet<StringRef> mappedKeys;
std::set<std::string> mappedKeys;
/// Whether all mappings so far were successful. If `map` is invoked and this
/// is false then the method will not even look for the key in the object.
bool mapValid = true;
Expand Down
37 changes: 37 additions & 0 deletions include/dynamatic/Support/RTL/RTLTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "dynamatic/Support/JSON/JSON.h"
#include "dynamatic/Support/LLVM.h"
#include "mlir/IR/Attributes.h"
#include "llvm/Support/JSON.h"
Expand Down Expand Up @@ -60,6 +61,15 @@ struct UnsignedConstraints : public RTLTypeConstraints {
std::optional<unsigned> ne;

bool verify(mlir::Attribute attr) const override;

/// Attempts to deserialize the unsigned constraints using the provided
/// deserializer. Exepcted key names are prefixed using the provided string.
/// The method does not check for the deserializer's validity.
json::ObjectDeserializer &deserialize(json::ObjectDeserializer &deserial,
StringRef keyPrefix = {});

/// Checks whether that the unsigned value honors the constraints.
bool verify(unsigned value) const;
};

/// ADL-findable LLVM-standard JSON deserializer for unsigned constraints.
Expand All @@ -80,6 +90,24 @@ struct StringConstraints : public RTLTypeConstraints {
bool fromJSON(const llvm::json::Value &value, StringConstraints &cons,
llvm::json::Path path);

/// Channel type constraints.
struct ChannelConstraints : public RTLTypeConstraints {
/// Constraints on the data signal's width.
UnsignedConstraints dataWidth;
/// Constraints on the total number of extra signals.
UnsignedConstraints numExtras;
/// Constraints on the number of extra downstream signals.
UnsignedConstraints numDownstreams;
/// Constraints on the number of extra upstream signals.
UnsignedConstraints numUpstreams;

bool verify(mlir::Attribute attr) const override;
};

/// ADL-findable LLVM-standard JSON deserializer for channel constraints.
bool fromJSON(const llvm::json::Value &value, ChannelConstraints &cons,
llvm::json::Path path);

//===----------------------------------------------------------------------===//
// RTLType and derived types
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -212,6 +240,15 @@ struct RTLStringType : public RTLType::Model<RTLStringType, StringConstraints> {
static std::string serialize(mlir::Attribute attr);
};

/// An RTL parameter representing a channel type, stored in the IR as a
/// `TypeAttr`.
struct RTLChannelType
: public RTLType::Model<RTLChannelType, ChannelConstraints> {
static constexpr llvm::StringLiteral ID = "channel";

static std::string serialize(mlir::Attribute attr);
};

} // namespace dynamatic

#endif // DYNAMATIC_SUPPORT_RTL_RTLTYPES_H
26 changes: 15 additions & 11 deletions lib/Support/JSON/JSON.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
Expand Down Expand Up @@ -192,33 +193,35 @@ ObjectDeserializer::ObjectDeserializer(const llvm::json::Object &obj,
llvm::json::Path path)
: obj(&obj), path(path) {}

ObjectDeserializer &ObjectDeserializer::map(StringRef key, const MapFn &fn) {
ObjectDeserializer &ObjectDeserializer::map(const Twine &key, const MapFn &fn) {
if (!mapValid || !obj)
return *this;
if (const llvm::json::Value *val = obj->get(key)) {
if (auto [_, newKey] = mappedKeys.insert(key); !newKey) {
path.field(key).report(ERR_DUP_KEY);
std::string keyStr = key.str();
if (const llvm::json::Value *val = obj->get(keyStr)) {
if (auto [_, newKey] = mappedKeys.insert(keyStr); !newKey) {
path.field(keyStr).report(ERR_DUP_KEY);
mapValid = false;
} else {
mapValid = fn(*val, path.field(key));
mapValid = fn(*val, path.field(keyStr));
}
return *this;
}
mapValid = false;
return *this;
}

ObjectDeserializer &ObjectDeserializer::mapOptional(StringRef key,
ObjectDeserializer &ObjectDeserializer::mapOptional(const Twine &key,
const MapFn &fn) {
if (!mapValid || !obj)
return *this;
if (const llvm::json::Value *val = obj->get(key)) {
if (auto [_, newKey] = mappedKeys.insert(key); !newKey) {
path.field(key).report(ERR_DUP_KEY);
std::string keyStr = key.str();
if (const llvm::json::Value *val = obj->get(keyStr)) {
if (auto [_, newKey] = mappedKeys.insert(keyStr); !newKey) {
path.field(keyStr).report(ERR_DUP_KEY);
mapValid = false;
return *this;
}
mapValid = fn(*val, path.field(key));
mapValid = fn(*val, path.field(keyStr));
}
return *this;
}
Expand All @@ -228,7 +231,8 @@ bool ObjectDeserializer::exhausted(const DenseSet<StringRef> &allowUnmapped) {
return false;
return llvm::all_of(*obj, [&](auto &keyAndVal) {
std::string key = keyAndVal.first.str();
if (mappedKeys.contains(key) || allowUnmapped.contains(key))
if ((mappedKeys.find(key) != mappedKeys.end()) ||
allowUnmapped.contains(key))
return true;
path.field(key).report("unmapped key in object");
return false;
Expand Down
108 changes: 77 additions & 31 deletions lib/Support/RTL/RTLTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "dynamatic/Support/RTL/RTLTypes.h"
#include "dynamatic/Dialect/Handshake/HandshakeTypes.h"
#include "dynamatic/Support/JSON/JSON.h"
#include "mlir/IR/BuiltinAttributes.h"

Expand All @@ -27,7 +28,7 @@ static const mlir::DenseSet<StringRef> RESERVED_KEYS{"name", KEY_TYPE,
"generic"};

static constexpr StringLiteral ERR_UNKNOWN_TYPE(
R"(unknown parameter type: options are "boolean", "unsigned", or "string")");
R"(unknown parameter type: options are "boolean", "unsigned", "string", or "channel")");

bool RTLType::fromJSON(const ljson::Value &value, ljson::Path path) {
if (typeConcept)
Expand All @@ -39,7 +40,8 @@ bool RTLType::fromJSON(const ljson::Value &value, ljson::Path path) {
path.field(KEY_TYPE).report(ERR_MISSING_VALUE);
return false;
}
if (!allocIf<RTLBooleanType, RTLUnsignedType, RTLStringType>(paramType)) {
if (!allocIf<RTLBooleanType, RTLUnsignedType, RTLStringType, RTLChannelType>(
paramType)) {
path.field(KEY_TYPE).report(ERR_UNKNOWN_TYPE);
return false;
}
Expand Down Expand Up @@ -76,9 +78,10 @@ bool UnsignedConstraints::verify(Attribute attr) const {
IntegerAttr intAttr = dyn_cast_if_present<IntegerAttr>(attr);
if (!intAttr || !intAttr.getType().isUnsignedInteger())
return false;
return verify(intAttr.getUInt());
}

// Check all constraints
unsigned value = intAttr.getUInt();
bool UnsignedConstraints::verify(unsigned value) const {
return (!lb || lb <= value) && (!ub || value <= ub) && (!eq || value == eq) &&
(!ne || value != ne);
}
Expand All @@ -88,50 +91,56 @@ static constexpr llvm::StringLiteral
ERR_ARRAY_FORMAT = "expected array to have [lb, ub] format",
ERR_LB = "lower bound already set", ERR_UB = "upper bound already set";

bool dynamatic::fromJSON(const ljson::Value &value, UnsignedConstraints &cons,
ljson::Path path) {
json::ObjectDeserializer &
UnsignedConstraints::deserialize(json::ObjectDeserializer &deserial,
StringRef keyPrefix) {
auto boundFromJSON = [&](StringLiteral err, const ljson::Value &value,
std::optional<unsigned> &bound,
ljson::Path keyPath) -> bool {
if (bound) {
// The bound may be set by the "range" key or the dedicated bound key,
// make sure there is no conflict
path.report(err);
keyPath.report(err);
return false;
}
return ljson::fromJSON(value, bound, keyPath);
};

return ObjectDeserializer(value, path)
.map("eq", cons.eq)
.map("ne", cons.ne)
.mapOptional("lb",
deserial.map(keyPrefix + "eq", eq)
.map(keyPrefix + "ne", ne)
.mapOptional(keyPrefix + "lb",
[&](auto &val, auto path) {
return boundFromJSON(ERR_LB, val, cons.lb, path);
return boundFromJSON(ERR_LB, val, lb, path);
})
.mapOptional("ub",
.mapOptional(keyPrefix + "ub",
[&](auto &val, auto path) {
return boundFromJSON(ERR_UB, val, cons.ub, path);
return boundFromJSON(ERR_UB, val, ub, path);
})
.mapOptional("range",
[&](auto &val, auto path) {
const ljson::Array *array = val.getAsArray();
if (!array) {
path.report(ERR_EXPECTED_ARRAY);
return false;
}
if (array->size() != 2) {
path.report(ERR_ARRAY_FORMAT);
return false;
}
return boundFromJSON(ERR_LB, (*array)[0], cons.lb, path) &&
boundFromJSON(ERR_UB, (*array)[1], cons.ub, path);
})
.exhausted(RESERVED_KEYS);
.mapOptional(keyPrefix + "range", [&](auto &val, auto path) {
const ljson::Array *array = val.getAsArray();
if (!array) {
path.report(ERR_EXPECTED_ARRAY);
return false;
}
if (array->size() != 2) {
path.report(ERR_ARRAY_FORMAT);
return false;
}
return boundFromJSON(ERR_LB, (*array)[0], lb, path) &&
boundFromJSON(ERR_UB, (*array)[1], ub, path);
});

return deserial;
}

bool dynamatic::fromJSON(const ljson::Value &value, UnsignedConstraints &cons,
ljson::Path path) {
ObjectDeserializer deserial(value, path);
return cons.deserialize(deserial).exhausted(RESERVED_KEYS);
}

std::string RTLUnsignedType::serialize(Attribute attr) {
IntegerAttr intAttr = dyn_cast_if_present<IntegerAttr>(attr);
auto intAttr = dyn_cast_if_present<IntegerAttr>(attr);
if (!intAttr)
return "";
return std::to_string(intAttr.getUInt());
Expand All @@ -153,8 +162,45 @@ bool dynamatic::fromJSON(const ljson::Value &value, StringConstraints &cons,
}

std::string RTLStringType::serialize(Attribute attr) {
StringAttr stringAttr = dyn_cast_if_present<StringAttr>(attr);
auto stringAttr = dyn_cast_if_present<StringAttr>(attr);
if (!stringAttr)
return "";
return stringAttr.str();
}

bool ChannelConstraints::verify(Attribute attr) const {
auto typeAttr = dyn_cast_if_present<TypeAttr>(attr);
if (!typeAttr)
return false;
auto channelType = dyn_cast<handshake::ChannelType>(typeAttr.getValue());
if (!channelType)
return false;

return dataWidth.verify(channelType.getDataBitWidth());
}

bool dynamatic::fromJSON(const ljson::Value &value, ChannelConstraints &cons,
ljson::Path path) {
ObjectDeserializer deserial(value, path);
cons.dataWidth.deserialize(deserial, "data-");
cons.numExtras.deserialize(deserial, "extra-");
cons.numDownstreams.deserialize(deserial, "down-");
cons.numUpstreams.deserialize(deserial, "up-");
return deserial.exhausted(RESERVED_KEYS);
}

std::string RTLChannelType::serialize(Attribute attr) {
auto typeAttr = dyn_cast_if_present<TypeAttr>(attr);
if (!typeAttr)
return "";
auto channelType = dyn_cast<handshake::ChannelType>(typeAttr.getValue());
if (!channelType)
return "";

// Convert the channel type to a string
std::stringstream ss;
ss << channelType.getDataBitWidth();
for (const handshake::ExtraSignal &extra : channelType.getExtraSignals())
ss << "-" << extra.name.str() << "-" << extra.getBitWidth();
return ss.str();
}
Loading

0 comments on commit 4e1c39f

Please sign in to comment.