Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ip integration #30

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/scalehls/Dialect/HLSCpp/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,12 @@ def MulOp : HLSCppOp<"mul", [NoSideEffect]> {
let results = (outs AnyType : $output);
}

def IncludeOp : HLSCppOp<"include", [NoSideEffect]> {
let summary = "C include library operation";
let description = [{}];

let arguments = (ins
StrArrayAttr:$libraries);
}

#endif // SCALEHLS_DIALECT_HLSCPP_STRUCTUREOPS_TD
10 changes: 7 additions & 3 deletions include/scalehls/Dialect/HLSCpp/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class HLSCppVisitorBase {
arith::IndexCastOp, arith::UIToFPOp, arith::SIToFPOp,
arith::FPToSIOp, arith::FPToUIOp,
// HLSCpp operations.
AssignOp, CastOp, MulOp, AddOp>([&](auto opNode) -> ResultType {
AssignOp, CastOp, MulOp, AddOp, IncludeOp>([&](auto opNode) -> ResultType {
return thisCast->visitOp(opNode, args...);
})
.Default([&](auto opNode) -> ResultType {
Expand All @@ -72,8 +72,9 @@ class HLSCppVisitorBase {

/// This callback is invoked on any invalid operations.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args) {
op->emitOpError("is unsupported operation.");
abort();
//op->emitOpError("is unsupported operation.");
//abort();
return ResultType();
}

/// This callback is invoked on any operations that are not handled by the
Expand Down Expand Up @@ -193,6 +194,9 @@ class HLSCppVisitorBase {
HANDLE(CastOp);
HANDLE(AddOp);
HANDLE(MulOp);

// HLS C++ library include operation.
HANDLE(IncludeOp);
#undef HANDLE
};
} // namespace scalehls
Expand Down
11 changes: 11 additions & 0 deletions include/scalehls/Dialect/HLSKernel/Interfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,15 @@ def HLSKernelOpInterface : OpInterface<"HLSKernelOpInterface"> {
];
}

def IPOp : HLSKernelOp<"ip", [HLSKernelOpInterface]> {
let summary = "General IP";
let description = [{}];

let arguments = (ins
Variadic<AnyType>:$inputs,
StrAttr:$path,
StrAttr:$name
);
}

#endif // SCALEHLS_DIALECT_HLSKERNEL_INTERFACES_TD
12 changes: 9 additions & 3 deletions include/scalehls/Dialect/HLSKernel/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class HLSKernelVisitorBase {
// CNN operations.
DenseOp, ConvOp, MaxPoolOp, ReluOp, MergeOp, CopyOp,
// BLAS operations.
GemmOp, SymmOp, SyrkOp, Syr2kOp, TrmmOp>(
GemmOp, SymmOp, SyrkOp, Syr2kOp, TrmmOp,
// IP operation.
IPOp>(
[&](auto opNode) -> ResultType {
return thisCast->visitOp(opNode, args...);
})
Expand All @@ -36,8 +38,9 @@ class HLSKernelVisitorBase {

/// This callback is invoked on any invalid operations.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args) {
op->emitOpError("is unsupported operation.");
abort();
//op->emitOpError("is unsupported operation.");
//abort();
return ResultType();
}

/// This callback is invoked on any operations that are not handled by the
Expand Down Expand Up @@ -66,6 +69,9 @@ class HLSKernelVisitorBase {
HANDLE(Syr2kOp);
HANDLE(TrmmOp);

// IP operation.
HANDLE(IPOp);

#undef HANDLE
};
} // namespace scalehls
Expand Down
81 changes: 79 additions & 2 deletions lib/Translation/EmitHLSCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Translation.h"
#include "scalehls/Dialect/HLSCpp/Visitor.h"
#include "scalehls/Dialect/HLSKernel/Visitor.h"
#include "scalehls/InitAllDialects.h"
#include "scalehls/Support/Utils.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
Expand Down Expand Up @@ -225,6 +229,9 @@ class ModuleEmitter : public ScaleHLSEmitterBase {
void emitBinary(Operation *op, const char *syntax);
void emitUnary(Operation *op, const char *syntax);

/// IP operation emitter.
void emitIP(IPOp op);

/// Special operation emitters.
void emitCall(CallOp op);
void emitSelect(SelectOp op);
Expand Down Expand Up @@ -482,6 +489,20 @@ class ExprVisitor : public HLSCppVisitorBase<ExprVisitor, bool> {
};
} // namespace

namespace {
class KernelVisitor : public HLSKernelVisitorBase<KernelVisitor, bool> {
public:
KernelVisitor(ModuleEmitter &emitter) : emitter(emitter) {}

using HLSKernelVisitorBase::visitOp;
/// IP operation.
bool visitOp(IPOp op) { return emitter.emitIP(op), true; }

private:
ModuleEmitter &emitter;
};
} // namespace

bool ExprVisitor::visitOp(arith::CmpFOp op) {
switch (op.getPredicate()) {
case arith::CmpFPredicate::OEQ:
Expand Down Expand Up @@ -1111,6 +1132,47 @@ void ModuleEmitter::emitUnary(Operation *op, const char *syntax) {
emitNestedLoopTail(rank);
}

/// IP operation emitter.
void ModuleEmitter::emitIP(IPOp op) {
// Emit IP source from JSON if IP exists.
std::string errorMessage;
if (auto jsonFile = mlir::openInputFile(op.path(), &errorMessage)) {
if (auto json = llvm::json::parse(jsonFile->getBuffer())) {
if (auto O = json->getAsObject()) {
if (auto source = O->getObject("source")) {
for (auto line : *source->getArray("code")) {
auto l = line.getAsString()->str();
for (size_t idx = 0; idx < source->getArray("params")->size(); idx++) {
auto p = source->getArray("params")->operator[](idx).getAsString()->str();
auto o = getName(op.getOperands()[idx]).str().str();
for (std::size_t pos = 0; l.npos != (pos = l.find(p, pos)); pos += o.length()) {
l.replace(pos, p.length(), o);
}
}

indent();
os << l << "\n";
}
return;
}
}
}
//emitError(op, "IP JSON cannot be parsed.");
}
//emitError(op, "IP cannot be found.");

// Emit a regular function call if IP does not exist.
os << " __IP__" << op.name() << "(";
unsigned argIdx = 0;
for (auto arg : op.getOperands()) {
emitValue(arg);
if (argIdx++ != op.getOperands().size() - 1) {
os << ", ";
}
}
os << ");\n";
}

/// Special operation emitters.
void ModuleEmitter::emitSelect(SelectOp op) {
unsigned rank = emitNestedLoopHead(op.getResult());
Expand Down Expand Up @@ -1342,6 +1404,9 @@ void ModuleEmitter::emitBlock(Block &block) {
if (StmtVisitor(*this).dispatchVisitor(&op))
continue;

if (KernelVisitor(*this).dispatchVisitor(&op))
continue;

emitError(&op, "can't be correctly emitted.");
}
}
Expand Down Expand Up @@ -1580,15 +1645,27 @@ void ModuleEmitter::emitModule(ModuleOp module) {
#include <math.h>
#include <stdint.h>

// Libraries included by user.
)XXX";

for (auto &op : *module.getBody()) {
if (auto include = dyn_cast<IncludeOp>(op)) {
for (auto library : include.libraries()) {
os << "#include <" << library.dyn_cast<StringAttr>().getValue() << ">\n";
}
}
}

os << R"XXX(
using namespace std;

)XXX";

for (auto &op : *module.getBody()) {
if (auto func = dyn_cast<FuncOp>(op))
emitFunction(func);
else
emitError(&op, "is unsupported operation.");
//else
// emitError(&op, "is unsupported operation.");
}
}

Expand Down